multi-agent-rlenv 3.5.5__tar.gz → 3.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/PKG-INFO +23 -6
  2. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/README.md +17 -2
  3. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/pyproject.toml +5 -6
  4. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/__init__.py +12 -3
  5. multi_agent_rlenv-3.6.0/src/marlenv/adapters/__init__.py +33 -0
  6. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/adapters/gym_adapter.py +6 -0
  7. multi_agent_rlenv-3.6.0/src/marlenv/catalog/__init__.py +26 -0
  8. multi_agent_rlenv-3.6.0/src/marlenv/catalog/deepsea.py +73 -0
  9. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/env_builder.py +1 -61
  10. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/models/spaces.py +1 -1
  11. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/utils/__init__.py +3 -0
  12. multi_agent_rlenv-3.6.0/src/marlenv/utils/import_placeholders.py +30 -0
  13. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/utils/schedule.py +16 -1
  14. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/tests/test_adapters.py +4 -107
  15. multi_agent_rlenv-3.6.0/tests/test_catalog.py +41 -0
  16. multi_agent_rlenv-3.6.0/tests/test_others.py +6 -0
  17. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/tests/test_schedules.py +15 -0
  18. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/tests/test_serialization.py +1 -83
  19. multi_agent_rlenv-3.5.5/src/marlenv/adapters/__init__.py +0 -42
  20. multi_agent_rlenv-3.5.5/src/marlenv/adapters/overcooked_adapter.py +0 -241
  21. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/.github/workflows/ci.yaml +0 -0
  22. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/.github/workflows/docs.yaml +0 -0
  23. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/.gitignore +0 -0
  24. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/LICENSE +0 -0
  25. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/adapters/pettingzoo_adapter.py +0 -0
  26. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/adapters/pymarl_adapter.py +0 -0
  27. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/adapters/smac_adapter.py +0 -0
  28. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/env_pool.py +0 -0
  29. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/exceptions.py +0 -0
  30. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/mock_env.py +0 -0
  31. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/models/__init__.py +0 -0
  32. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/models/env.py +0 -0
  33. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/models/episode.py +0 -0
  34. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/models/observation.py +0 -0
  35. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/models/state.py +0 -0
  36. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/models/step.py +0 -0
  37. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/models/transition.py +0 -0
  38. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/py.typed +0 -0
  39. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/utils/cached_property_collector.py +0 -0
  40. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/__init__.py +0 -0
  41. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/action_randomizer.py +0 -0
  42. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/agent_id_wrapper.py +0 -0
  43. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/available_actions_mask.py +0 -0
  44. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/available_actions_wrapper.py +0 -0
  45. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/blind_wrapper.py +0 -0
  46. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/centralised.py +0 -0
  47. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/delayed_rewards.py +0 -0
  48. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/last_action_wrapper.py +0 -0
  49. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/paddings.py +0 -0
  50. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/penalty_wrapper.py +0 -0
  51. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/potential_shaping.py +0 -0
  52. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/rlenv_wrapper.py +0 -0
  53. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/time_limit.py +0 -0
  54. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/video_recorder.py +0 -0
  55. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/tests/__init__.py +0 -0
  56. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/tests/test_episode.py +0 -0
  57. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/tests/test_models.py +0 -0
  58. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/tests/test_pool.py +0 -0
  59. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/tests/test_spaces.py +0 -0
  60. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/tests/test_wrappers.py +0 -0
  61. {multi_agent_rlenv-3.5.5 → multi_agent_rlenv-3.6.0}/tests/utils.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: multi-agent-rlenv
- Version: 3.5.5
+ Version: 3.6.0
  Summary: A strongly typed Multi-Agent Reinforcement Learning framework
  Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
  Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
@@ -13,7 +13,8 @@ Requires-Dist: opencv-python>=4.0
  Requires-Dist: typing-extensions>=4.0
  Provides-Extra: all
  Requires-Dist: gymnasium>0.29.1; extra == 'all'
- Requires-Dist: overcooked-ai; extra == 'all'
+ Requires-Dist: laser-learning-environment>=2.6.1; extra == 'all'
+ Requires-Dist: overcooked>=0.1.0; extra == 'all'
  Requires-Dist: pettingzoo>=1.20; extra == 'all'
  Requires-Dist: pymunk>=6.0; extra == 'all'
  Requires-Dist: pysc2; extra == 'all'
@@ -22,9 +23,10 @@ Requires-Dist: smac; extra == 'all'
  Requires-Dist: torch>=2.0; extra == 'all'
  Provides-Extra: gym
  Requires-Dist: gymnasium>=0.29.1; extra == 'gym'
+ Provides-Extra: lle
+ Requires-Dist: laser-learning-environment>=2.6.1; extra == 'lle'
  Provides-Extra: overcooked
- Requires-Dist: overcooked-ai>=1.1.0; extra == 'overcooked'
- Requires-Dist: scipy>=1.10; extra == 'overcooked'
+ Requires-Dist: overcooked>=0.1.0; extra == 'overcooked'
  Provides-Extra: pettingzoo
  Requires-Dist: pettingzoo>=1.20; extra == 'pettingzoo'
  Requires-Dist: pymunk>=6.0; extra == 'pettingzoo'
@@ -62,9 +64,24 @@ $ pip install marlenv[smac] # Install SMAC
  $ pip install marlenv[gym,smac] # Install Gym & smac support
  ```
 
+ ## Using the `marlenv` environment catalog
+ Some environments are registered in `marlenv` and can easily be instantiated via its catalog.
+
+ ```python
+ from marlenv import catalog
+
+ env1 = catalog.Overcooked.from_layout("scenario4")
+ env2 = catalog.LLE.level(6)
+ env3 = catalog.DeepSea(max_depth=5)
+ ```
+ Note that using the catalog requires the corresponding environment package to be installed. For instance, you need to install the `laser-learning-environment` package to use `catalog.LLE`, which can be done with the corresponding extra at installation time, as shown below.
+ ```bash
+ pip install multi-agent-rlenv[lle]
+ ```
+
 
  ## Using `marlenv` with existing libraries
- `marlenv` unifies multiple popular libraries under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.
+ `marlenv` provides adapters for the most popular libraries to unify them under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.
 
  ```python
  import marlenv
@@ -74,7 +91,7 @@ gym_env = marlenv.make("CartPole-v1", seed=25)
 
  # You can seamlessly instantiate a SMAC environment and directly pass your required arguments
  from marlenv.adapters import SMAC
- smac_env = env2 = SMAC("3m", debug=True, difficulty="9")
+ smac_env = SMAC("3m", debug=True, difficulty="9")
 
  # pettingzoo is also supported
  from pettingzoo.sisl import pursuit_v4
@@ -24,9 +24,24 @@ $ pip install marlenv[smac] # Install SMAC
  $ pip install marlenv[gym,smac] # Install Gym & smac support
  ```
 
+ ## Using the `marlenv` environment catalog
+ Some environments are registered in `marlenv` and can easily be instantiated via its catalog.
+
+ ```python
+ from marlenv import catalog
+
+ env1 = catalog.Overcooked.from_layout("scenario4")
+ env2 = catalog.LLE.level(6)
+ env3 = catalog.DeepSea(max_depth=5)
+ ```
+ Note that using the catalog requires the corresponding environment package to be installed. For instance, you need to install the `laser-learning-environment` package to use `catalog.LLE`, which can be done with the corresponding extra at installation time, as shown below.
+ ```bash
+ pip install multi-agent-rlenv[lle]
+ ```
+
 
  ## Using `marlenv` with existing libraries
- `marlenv` unifies multiple popular libraries under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.
+ `marlenv` provides adapters for the most popular libraries to unify them under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.
 
  ```python
  import marlenv
@@ -36,7 +51,7 @@ gym_env = marlenv.make("CartPole-v1", seed=25)
 
  # You can seamlessly instantiate a SMAC environment and directly pass your required arguments
  from marlenv.adapters import SMAC
- smac_env = env2 = SMAC("3m", debug=True, difficulty="9")
+ smac_env = SMAC("3m", debug=True, difficulty="9")
 
  # pettingzoo is also supported
  from pettingzoo.sisl import pursuit_v4
@@ -1,6 +1,6 @@
  [project]
  name = "multi-agent-rlenv"
- dynamic = ["version"]
+ version = "3.6.0"
  description = "A strongly typed Multi-Agent Reinforcement Learning framework"
  authors = [
      { "name" = "Yannick Molinghen", "email" = "yannick.molinghen@ulb.be" },
@@ -19,14 +19,16 @@ dependencies = ["numpy>=2.0.0", "opencv-python>=4.0", "typing_extensions>=4.0"]
  gym = ["gymnasium>=0.29.1"]
  smac = ["smac", "pysc2"]
  pettingzoo = ["pettingzoo>=1.20", "pymunk>=6.0", "scipy>=1.10"]
- overcooked = ["overcooked-ai>=1.1.0", "scipy>=1.10"]
+ overcooked = ["overcooked>=0.1.0"]
+ lle = ["laser-learning-environment>=2.6.1"]
  torch = ["torch>=2.0"]
  all = [
      "gymnasium>0.29.1",
      "pettingzoo>=1.20",
-     "overcooked-ai",
+     "overcooked>=0.1.0",
      "smac",
      "pysc2",
+     "laser-learning-environment>=2.6.1",
      "pymunk>=6.0",
      "scipy>=1.10",
      "torch>=2.0",
@@ -42,9 +44,6 @@ line-length = 140
  [tool.hatch.build.targets.wheel]
  packages = ["src/marlenv"]
 
- [tool.hatch]
- version = { "path" = "src/marlenv/__init__.py", "attr" = "__version__" }
-
 
  [tool.pytest.ini_options]
  testpaths = ["tests"]
@@ -62,7 +62,13 @@ print(env.extras_shape) # (1, )
  If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper`, which implements a default behaviour for every method.
  """
 
- __version__ = "3.5.5"
+ from importlib.metadata import version, PackageNotFoundError
+
+ try:
+     __version__ = version("multi-agent-rlenv")
+ except PackageNotFoundError:
+     __version__ = "0.0.0"  # practical fallback in dev/CI
+
 
  from . import models
  from .models import (
@@ -82,16 +88,19 @@ from .models import (
 
  from . import wrappers
  from . import adapters
- from .env_builder import make, Builder
+ from .env_builder import Builder
  from .wrappers import RLEnvWrapper
  from .mock_env import DiscreteMockEnv, DiscreteMOMockEnv
+ from . import catalog
+ from .adapters import make
 
  __all__ = [
      "models",
+     "make",
+     "catalog",
      "wrappers",
      "adapters",
      "spaces",
-     "make",
      "Builder",
      "MARLEnv",
      "Step",
@@ -0,0 +1,33 @@
+ from importlib.util import find_spec
+ from .pymarl_adapter import PymarlAdapter
+ from marlenv.utils import DummyClass, dummy_function
+
+ HAS_GYM = find_spec("gymnasium") is not None
+ if HAS_GYM:
+     from .gym_adapter import Gym, make
+ else:
+     Gym = DummyClass("gymnasium")
+     make = dummy_function("gymnasium")
+
+ HAS_PETTINGZOO = find_spec("pettingzoo") is not None
+ if HAS_PETTINGZOO:
+     from .pettingzoo_adapter import PettingZoo
+ else:
+     PettingZoo = DummyClass("pettingzoo")
+
+ HAS_SMAC = find_spec("smac") is not None
+ if HAS_SMAC:
+     from .smac_adapter import SMAC
+ else:
+     SMAC = DummyClass("smac", "https://github.com/oxwhirl/smac.git")
+
+ __all__ = [
+     "PymarlAdapter",
+     "Gym",
+     "make",
+     "PettingZoo",
+     "SMAC",
+     "HAS_GYM",
+     "HAS_PETTINGZOO",
+     "HAS_SMAC",
+ ]
@@ -78,3 +78,9 @@ class Gym(MARLEnv[Space]):
 
      def seed(self, seed_value: int):
          self._gym_env.reset(seed=seed_value)
+
+
+ def make(env_id: str, **kwargs):
+     """Make an RLEnv from a registered `gymnasium` environment id (e.g. "CartPole-v1")."""
+     gym_env = gym.make(env_id, render_mode="rgb_array", **kwargs)
+     return Gym(gym_env)
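
A minimal usage sketch of this new module-level helper (assuming the `gymnasium` extra is installed; as the `__init__.py` hunk above shows, `make` is re-exported from `marlenv.adapters` and from the top-level `marlenv` package):

```python
# Sketch, assuming `pip install multi-agent-rlenv[gym]`.
from marlenv.adapters import make

env = make("CartPole-v1")  # wraps gymnasium.make(...) in the Gym adapter
obs, state = env.reset()
```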
@@ -0,0 +1,26 @@
+ from importlib.util import find_spec
+ from ..utils.import_placeholders import DummyClass
+ from marlenv.adapters import SMAC
+ from .deepsea import DeepSea
+
+
+ HAS_LLE = find_spec("lle") is not None
+ if HAS_LLE:
+     from lle import LLE  # pyright: ignore[reportMissingImports]
+ else:
+     LLE = DummyClass("lle", "laser-learning-environment")
+
+ HAS_OVERCOOKED = find_spec("overcooked") is not None
+ if HAS_OVERCOOKED:
+     from overcooked import Overcooked  # pyright: ignore[reportMissingImports]
+ else:
+     Overcooked = DummyClass("overcooked", "overcooked")
+
+ __all__ = [
+     "Overcooked",
+     "SMAC",
+     "LLE",
+     "DeepSea",
+     "HAS_LLE",
+     "HAS_OVERCOOKED",
+ ]
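
A short sketch of how the exported feature flags can guard access to optional catalog entries (a usage pattern implied by the module above, not a documented API beyond these exports):

```python
# Sketch: check the feature flag before touching an optional entry.
from marlenv import catalog

if catalog.HAS_LLE:
    env = catalog.LLE.level(6)
else:
    print("Install `laser-learning-environment` to use catalog.LLE")
```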
@@ -0,0 +1,73 @@
+ from typing import Sequence
+ import numpy as np
+ from marlenv import MARLEnv, MultiDiscreteSpace, DiscreteSpace, Observation, State, Step
+ from dataclasses import dataclass
+
+
+ LEFT = 0
+ RIGHT = 1
+
+
+ @dataclass
+ class DeepSea(MARLEnv[MultiDiscreteSpace]):
+     """
+     Deep Sea single-agent environment to test for deep exploration. The probability of reaching the goal state under random exploration is 2^(-max_depth).
+
+     The agent explores a 2D grid where the bottom-right corner (max_depth, max_depth) is the goal and is the only state to yield a reward.
+     The agent starts in the top-left corner (0, 0).
+     The agent has two actions: left or right, and taking an action makes the agent dive one row deeper. The agent cannot go beyond the grid boundaries.
+     Going right gives a penalty of (0.01 / max_depth).
+     """
+
+     max_depth: int
+
+     def __init__(self, max_depth: int):
+         super().__init__(
+             n_agents=1,
+             action_space=DiscreteSpace(size=2, labels=["left", "right"]).repeat(1),
+             observation_shape=(2,),
+             state_shape=(2,),
+         )
+         self.max_depth = max_depth
+         self._row = 0
+         self._col = 0
+         self._step_right_penalty = -0.01 / self.max_depth
+
+     def get_observation(self) -> Observation:
+         return Observation(np.array([self._row, self._col], dtype=np.float32), self.available_actions())
+
+     def get_state(self) -> State:
+         return State(np.array([self._row, self._col], dtype=np.float32))
+
+     def reset(self):
+         self._row = 0
+         self._col = 0
+         return self.get_observation(), self.get_state()
+
+     def step(self, action: Sequence[int]):
+         self._row += 1
+         if action[0] == LEFT:
+             self._col -= 1
+         else:
+             self._col += 1
+         self._col = max(0, self._col)
+         if action[0] == RIGHT:
+             if self._row == self.max_depth:
+                 reward = 1.0
+             else:
+                 reward = self._step_right_penalty
+         else:
+             reward = 0.0
+         return Step(
+             self.get_observation(),
+             self.get_state(),
+             reward,
+             done=self._row == self.max_depth,
+         )
+
+     def set_state(self, state: State):
+         self._row, self._col = state.data
+
+     @property
+     def agent_state_size(self):
+         return 2
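
To make the dynamics concrete, here is a minimal random-rollout sketch; it only uses the `reset`/`step` API shown in this diff, and the episode ends after exactly `max_depth` steps:

```python
# Sketch: random policy on DeepSea. Reaching the goal requires choosing
# RIGHT at every step, which happens with probability 2**(-max_depth).
import numpy as np
from marlenv.catalog import DeepSea

env = DeepSea(max_depth=5)
obs, state = env.reset()
total_reward, done = 0.0, False
while not done:
    action = [np.random.randint(2)]  # single agent: 0 = left, 1 = right
    step = env.step(action)
    total_reward += np.sum(step.reward)
    done = step.done
print(total_reward)  # ~1.0 (minus small right-step penalties) only if the goal was reached
```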
@@ -1,73 +1,13 @@
  from dataclasses import dataclass
- from typing import Generic, Literal, Optional, TypeVar, overload
+ from typing import Generic, Literal, Optional, TypeVar
 
  import numpy as np
  import numpy.typing as npt
 
  from . import wrappers
- from marlenv import adapters
  from .models import Space, MARLEnv
 
  AS = TypeVar("AS", bound=Space)
 
- if adapters.HAS_PETTINGZOO:
-     from .adapters import PettingZoo
-     from pettingzoo import ParallelEnv
-
-     @overload
-     def make(env: ParallelEnv) -> PettingZoo: ...
-
-
- if adapters.HAS_GYM:
-     from .adapters import Gym
-     from gymnasium import Env
-     import gymnasium
-
-     @overload
-     def make(env: Env) -> Gym: ...
-
-     @overload
-     def make(env: str, **kwargs) -> Gym:
-         """
-         Make an RLEnv from the `gymnasium` registry (e.g: "CartPole-v1").
-         """
-
-
- if adapters.HAS_SMAC:
-     from .adapters import SMAC
-     from smac.env import StarCraft2Env
-
-     @overload
-     def make(env: StarCraft2Env) -> SMAC: ...
-
-
- if adapters.HAS_OVERCOOKED:
-     from .adapters import Overcooked
-     from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
-
-     @overload
-     def make(env: OvercookedEnv) -> Overcooked: ...
-
-
- def make(env, **kwargs):
-     """Make an RLEnv from str (Gym) or PettingZoo"""
-     match env:
-         case MARLEnv():
-             return env
-         case str(env_id):
-             if adapters.HAS_GYM:
-                 gym_env = gymnasium.make(env_id, render_mode="rgb_array", **kwargs)
-                 return Gym(gym_env)
-
-     if adapters.HAS_PETTINGZOO and isinstance(env, ParallelEnv):
-         return PettingZoo(env)  # type: ignore
-     if adapters.HAS_SMAC and isinstance(env, StarCraft2Env):
-         return SMAC(env)
-     if adapters.HAS_OVERCOOKED and isinstance(env, OvercookedEnv):
-         return Overcooked(env)  # type: ignore
-     if adapters.HAS_GYM and isinstance(env, Env):
-         return Gym(env)
-     raise ValueError(f"Unknown environment type: {type(env)}")
-
 
  @dataclass
  class Builder(Generic[AS]):
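
Note that this removal does not change the public entry point: as the `__init__.py` hunk earlier in this diff shows, `make` is now imported from `marlenv.adapters` and still re-exported at the top level. A small sketch of the unchanged call site (assuming the `gym` extra is installed):

```python
# Sketch: the documented top-level entry point keeps working in 3.6.0.
import marlenv

env = marlenv.make("CartPole-v1")  # now resolved via marlenv.adapters.make
```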
@@ -105,7 +105,7 @@ class MultiDiscreteSpace(Space):
      def sample(self, mask: Optional[npt.NDArray[np.bool] | list[npt.NDArray[np.bool]]] = None):
          if mask is None:
              return np.array([space.sample() for space in self.spaces], dtype=np.int32)
-         return np.array([space.sample(mask=mask) for mask, space in zip(mask, self.spaces)], dtype=np.int32)
+         return np.array([space.sample(mask=mask) for mask, space in zip(mask, self.spaces)], dtype=np.int32)  # type: ignore
 
      def __eq__(self, value: object) -> bool:
          if not isinstance(value, MultiDiscreteSpace):
@@ -1,5 +1,6 @@
  from .cached_property_collector import CachedPropertyCollector, CachedPropertyInvalidator
  from .schedule import ExpSchedule, LinearSchedule, MultiSchedule, RoundedSchedule, Schedule
+ from .import_placeholders import DummyClass, dummy_function
 
  __all__ = [
      "Schedule",
@@ -9,4 +10,6 @@ __all__ = [
      "RoundedSchedule",
      "CachedPropertyCollector",
      "CachedPropertyInvalidator",
+     "DummyClass",
+     "dummy_function",
  ]
@@ -0,0 +1,30 @@
+ from typing import Optional, Any
+
+
+ class DummyClass:
+     def __init__(self, module_name: str, package_name: Optional[str] = None):
+         self.module_name = module_name
+         if package_name is None:
+             self.package_name = module_name
+         else:
+             self.package_name = package_name
+
+     def _raise_error(self):
+         raise ImportError(
+             f"The optional dependency `{self.module_name}` is not installed.\nInstall the `{self.package_name}` package (e.g. pip install {self.package_name})."
+         )
+
+     def __getattr__(self, _):
+         self._raise_error()
+
+     def __call__(self, *args, **kwargs):
+         self._raise_error()
+
+
+ def dummy_function(module_name: str, package_name: Optional[str] = None):
+     dummy = DummyClass(module_name, package_name)
+
+     def fail(*args, **kwargs) -> Any:
+         dummy()
+
+     return fail
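
A quick sketch of the placeholder behaviour: importing always succeeds, and the `ImportError` is deferred until the missing optional dependency is actually used (the `SMAC` line mirrors the new `adapters/__init__.py` above):

```python
# Sketch: the placeholder fails at use time, not at import time.
from marlenv.utils import DummyClass

SMAC = DummyClass("smac", "https://github.com/oxwhirl/smac.git")

try:
    SMAC("3m")  # __call__ raises because `smac` is not installed
except ImportError as e:
    print(e)
```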
@@ -1,6 +1,6 @@
  from abc import abstractmethod
  from dataclasses import dataclass
- from typing import Callable, Optional, TypeVar
+ from typing import Any, Callable, Optional, TypeVar
 
  T = TypeVar("T")
 
@@ -142,6 +142,21 @@ class Schedule:
      def __int__(self) -> int:
          return int(self.value)
 
+     @staticmethod
+     def from_json(data: dict[str, Any]):
+         """Create a Schedule from a JSON-like dictionary."""
+         classname = data.get("name")
+         if classname == "LinearSchedule":
+             return LinearSchedule(data["start_value"], data["end_value"], data["n_steps"])
+         elif classname == "ExpSchedule":
+             return ExpSchedule(data["start_value"], data["end_value"], data["n_steps"])
+         elif classname == "ConstantSchedule":
+             return ConstantSchedule(data["value"])
+         elif classname == "ArbitrarySchedule":
+             raise NotImplementedError("ArbitrarySchedule cannot be deserialized from JSON")
+         else:
+             raise ValueError(f"Unknown schedule type: {classname}")
+
 
  @dataclass(eq=False)
  class LinearSchedule(Schedule):
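
A small usage sketch of `from_json`, mirroring the `test_from_json` case added later in this diff:

```python
from marlenv.utils import Schedule

data = {"name": "LinearSchedule", "start_value": 0, "end_value": 1, "n_steps": 10}
schedule = Schedule.from_json(data)
schedule.update(5)
print(schedule.value)  # 0.5
```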
@@ -8,9 +8,6 @@ from marlenv.adapters import PymarlAdapter
  skip_gym = not marlenv.adapters.HAS_GYM
  skip_pettingzoo = not marlenv.adapters.HAS_PETTINGZOO
  skip_smac = not marlenv.adapters.HAS_SMAC
- # Check for "overcooked_ai_py.mdp" specifically because after uninstalling, the package
- # can still be found because of some remaining .pkl file.
- skip_overcooked = not marlenv.adapters.HAS_OVERCOOKED
 
 
  @pytest.mark.skipif(skip_gym, reason="Gymnasium is not installed")
@@ -80,10 +77,9 @@ def test_pettingzoo_adapter_discrete_action():
 
  @pytest.mark.skipif(skip_pettingzoo, reason="PettingZoo is not installed")
  def test_pettingzoo_adapter_continuous_action():
-     from pettingzoo.sisl import waterworld_v4
+     from pettingzoo.mpe import simple_v3
 
-     # https://pettingzoo.farama.org/environments/sisl/waterworld/
-     env = marlenv.adapters.PettingZoo(waterworld_v4.parallel_env())
+     env = marlenv.adapters.PettingZoo(simple_v3.parallel_env(continuous_actions=True))
      env.reset()
      action = env.action_space.sample()
      step = env.step(action)
@@ -93,8 +89,8 @@ def test_pettingzoo_adapter_continuous_action():
      assert isinstance(step.done, bool)
      assert isinstance(step.truncated, bool)
      assert isinstance(step.info, dict)
-     assert env.n_actions == 2
-     assert env.n_agents == 2
+     assert env.n_actions == 5
+     assert env.n_agents == 1
      assert isinstance(env.action_space, ContinuousSpace)
 
 
@@ -135,105 +131,6 @@ def test_smac_render():
      env.render()
 
 
- @pytest.mark.skipif(skip_overcooked, reason="Overcooked is not installed")
- def test_overcooked_attributes():
-     from overcooked_ai_py.mdp.overcooked_mdp import Action
-
-     from marlenv.adapters import Overcooked
-
-     env = Overcooked.from_layout("simple_o")
-     height, width = env._mdp.shape
-     assert env.n_agents == 2
-     assert env.n_actions == Action.NUM_ACTIONS
-     assert env.observation_shape == (25, height, width)
-     assert env.reward_space.shape == (1,)
-     assert env.extras_shape == (2,)
-     assert not env.is_multi_objective
-
-
- @pytest.mark.skipif(skip_overcooked, reason="Overcooked is not installed")
- def test_overcooked_obs_state():
-     from marlenv.adapters import Overcooked
-
-     HORIZON = 100
-     env = Overcooked.from_layout("coordination_ring", horizon=HORIZON)
-     height, width = env._mdp.shape
-     obs, state = env.reset()
-     for i in range(HORIZON):
-         assert obs.data.dtype == np.float32
-         assert state.data.dtype == np.float32
-         assert obs.extras.dtype == np.float32
-         assert state.extras.dtype == np.float32
-         assert obs.shape == (25, height, width)
-         assert obs.extras_shape == (2,)
-         assert state.shape == (25, height, width)
-         assert state.extras_shape == (2,)
-
-         assert np.all(obs.extras[:, 0] == i / HORIZON)
-         assert np.all(state.extras[0] == i / HORIZON)
-
-         step = env.random_step()
-         obs = step.obs
-         state = step.state
-         if i < HORIZON - 1:
-             assert not step.done
-         else:
-             assert step.done
-
-
- @pytest.mark.skipif(skip_overcooked, reason="Overcooked is not installed")
- def test_overcooked_shaping():
-     from marlenv.adapters import Overcooked
-
-     UP = 0
-     RIGHT = 2
-     STAY = 4
-     INTERACT = 5
-     grid = [
-         ["X", "X", "X", "D", "X"],
-         ["X", "O", "S", "2", "X"],
-         ["X", "1", "P", " ", "X"],
-         ["X", "T", "S", " ", "X"],
-         ["X", "X", "X", "X", "X"],
-     ]
-
-     env = Overcooked.from_grid(grid)
-     env.reset()
-     actions_rewards = [
-         ([UP, STAY], False),
-         ([INTERACT, STAY], False),
-         ([RIGHT, STAY], False),
-         ([INTERACT, STAY], True),
-     ]
-
-     for action, expected_reward in actions_rewards:
-         step = env.step(action)
-         if expected_reward:
-             assert step.reward.item() > 0
-
-
- @pytest.mark.skipif(skip_overcooked, reason="Overcooked is not installed")
- def test_overcooked_name():
-     from marlenv.adapters import Overcooked
-
-     grid = [
-         ["X", "X", "X", "D", "X"],
-         ["X", "O", "S", "2", "X"],
-         ["X", "1", "P", " ", "X"],
-         ["X", "T", "S", " ", "X"],
-         ["X", "X", "X", "X", "X"],
-     ]
-
-     env = Overcooked.from_grid(grid)
-     assert env.name == "Overcooked-custom-layout"
-
-     env = Overcooked.from_grid(grid, layout_name="my incredible grid")
-     assert env.name == "Overcooked-my incredible grid"
-
-     env = Overcooked.from_layout("asymmetric_advantages")
-     assert env.name == "Overcooked-asymmetric_advantages"
-
-
  def test_pymarl():
      LIMIT = 20
      N_AGENTS = 2
@@ -0,0 +1,41 @@
+ import pytest
+ from marlenv import catalog
+ from marlenv.utils import DummyClass, dummy_function
+
+ skip_lle = isinstance(catalog.LLE, DummyClass)
+ skip_overcooked = isinstance(catalog.Overcooked, DummyClass)
+
+
+ @pytest.mark.skipif(skip_lle, reason="LLE is not installed")
+ def test_lle():
+     catalog.LLE.level(1)
+
+
+ @pytest.mark.skipif(skip_overcooked, reason="Overcooked is not installed")
+ def test_overcooked():
+     catalog.Overcooked.from_layout("scenario4")
+
+
+ def test_dummy_class():
+     try:
+         x = DummyClass("")
+         x.abc
+         assert False, "Expected an ImportError because the placeholder module is not installed"
+     except ImportError:
+         pass
+
+     try:
+         x = DummyClass("")
+         x.abc()
+         assert False, "Expected an ImportError because the placeholder module is not installed"
+     except ImportError:
+         pass
+
+
+ def test_dummy_function():
+     try:
+         f = dummy_function("")
+         f()
+         assert False, "Expected an ImportError because the placeholder module is not installed"
+     except ImportError:
+         pass
@@ -0,0 +1,6 @@
+ import marlenv
+
+
+ def test_version():
+     assert hasattr(marlenv, "__version__")
+     x, y, z = marlenv.__version__.split(".")
@@ -153,3 +153,18 @@ def test_inequality_different_schedules():
      s3 = Schedule.linear(1, 2, 10)
      assert s1 != s2
      assert not s1 != s3
+
+
+ def test_from_json():
+     json_data = {"name": "LinearSchedule", "start_value": 0, "end_value": 1, "n_steps": 10}
+     s = Schedule.from_json(json_data)
+     assert isinstance(s, Schedule)
+     assert is_close(s.value, 0)
+     s.update()
+     assert is_close(s.value, 0.1)
+     s.update(5)
+     assert is_close(s.value, 0.5)
+     s.update(10)
+     assert is_close(s.value, 1.0)
+     s.update(15)
+     assert is_close(s.value, 1.0)
@@ -1,9 +1,8 @@
  import pickle
+
  import numpy as np
  import orjson
  import pytest
- import os
- from copy import deepcopy
 
  import marlenv
  from marlenv import DiscreteMockEnv, wrappers
@@ -193,87 +192,6 @@ def test_serialize_episode():
      _ = orjson.dumps(episode, option=orjson.OPT_SERIALIZE_NUMPY)
 
 
- @pytest.mark.skipif(not marlenv.adapters.HAS_OVERCOOKED, reason="Overcooked is not installed")
- def test_deepcopy_overcooked():
-     env = marlenv.adapters.Overcooked.from_layout("scenario4")
-     env2 = deepcopy(env)
-     assert env == env2
-
-
- @pytest.mark.skipif(not marlenv.adapters.HAS_OVERCOOKED, reason="Overcooked is not installed")
- def test_deepcopy_overcooked_schedule():
-     env = marlenv.adapters.Overcooked.from_layout("scenario4", reward_shaping_factor=Schedule.linear(1, 0, 10))
-     env2 = deepcopy(env)
-     assert env == env2
-
-     env.random_step()
-     assert not env == env2, "The reward shaping factor should be different"
-
-
- @pytest.mark.skipif(not marlenv.adapters.HAS_OVERCOOKED, reason="Overcooked is not installed")
- def test_pickle_overcooked():
-     env = marlenv.adapters.Overcooked.from_layout("scenario1_s", horizon=60)
-     serialized = pickle.dumps(env)
-     restored = pickle.loads(serialized)
-     assert env == restored
-
-     env.reset()
-     restored.reset()
-
-     for _ in range(50):
-         actions = env.sample_action()
-         step = env.step(actions)
-         step_restored = restored.step(actions)
-         assert step == step_restored
-
-
- @pytest.mark.skipif(not marlenv.adapters.HAS_OVERCOOKED, reason="Overcooked is not installed")
- def test_unpickling_from_blank_process():
-     from marlenv.adapters import Overcooked
-     import pickle
-     import subprocess
-     import tempfile
-
-     env = Overcooked.from_layout("large_room")
-     env_file = tempfile.NamedTemporaryFile("wb", delete=False)
-     pickle.dump(env, env_file)
-     env_file.close()
-
-     # Write the python file
-
-     f = tempfile.NamedTemporaryFile("w", delete=False)
-     f.write("""
- import pickle
- import sys
-
- with open(sys.argv[1], "rb") as f:
-     env = pickle.load(f)
-
- env.reset()""")
-     f.close()
-     try:
-         output = subprocess.run(f"python {f.name} {env_file.name}", shell=True, capture_output=True)
-         assert output.returncode == 0, output.stderr.decode("utf-8")
-     finally:
-         os.remove(f.name)
-         os.remove(env_file.name)
-
-
- @pytest.mark.skipif(not marlenv.adapters.HAS_OVERCOOKED, reason="Overcooked is not installed")
- def test_serialize_json_overcooked():
-     env = marlenv.adapters.Overcooked.from_layout("scenario1_s", horizon=60)
-     res = orjson.dumps(env, option=orjson.OPT_SERIALIZE_NUMPY)
-     deserialized = orjson.loads(res)
-
-     assert deserialized["n_agents"] == env.n_agents
-     assert tuple(deserialized["observation_shape"]) == env.observation_shape
-     assert tuple(deserialized["state_shape"]) == env.state_shape
-     assert tuple(deserialized["extras_shape"]) == env.extras_shape
-     assert deserialized["n_actions"] == env.n_actions
-     assert deserialized["name"] == env.name
-     assert deserialized["extras_meanings"] == env.extras_meanings
-
-
  @pytest.mark.skipif(not marlenv.adapters.HAS_GYM, reason="Gymnasium is not installed")
  def test_json_serialize_gym():
      env = marlenv.make("CartPole-v1")
@@ -1,42 +0,0 @@
- from importlib.util import find_spec
- from .pymarl_adapter import PymarlAdapter
-
- HAS_GYM = False
- if find_spec("gymnasium") is not None:
-     from .gym_adapter import Gym
-
-     HAS_GYM = True
-
- HAS_PETTINGZOO = False
- if find_spec("pettingzoo") is not None:
-     from .pettingzoo_adapter import PettingZoo
-
-     HAS_PETTINGZOO = True
-
- HAS_SMAC = False
- if find_spec("smac") is not None:
-     from .smac_adapter import SMAC
-
-     HAS_SMAC = True
-
- HAS_OVERCOOKED = False
- if find_spec("overcooked_ai_py") is not None and find_spec("overcooked_ai_py.mdp") is not None:
-     import numpy
-
-     # Overcooked assumes a version of numpy <2.0 where np.Inf is available.
-     setattr(numpy, "Inf", numpy.inf)
-     from .overcooked_adapter import Overcooked
-
-     HAS_OVERCOOKED = True
-
- __all__ = [
-     "PymarlAdapter",
-     "Gym",
-     "PettingZoo",
-     "SMAC",
-     "Overcooked",
-     "HAS_GYM",
-     "HAS_PETTINGZOO",
-     "HAS_SMAC",
-     "HAS_OVERCOOKED",
- ]
@@ -1,241 +0,0 @@
- import sys
- from dataclasses import dataclass
- from typing import Literal, Sequence, Optional
- from copy import deepcopy
-
- import cv2
- import numpy as np
- import numpy.typing as npt
- import pygame
- from marlenv.models import ContinuousSpace, DiscreteSpace, MARLEnv, Observation, State, Step, MultiDiscreteSpace
- from marlenv.utils import Schedule
-
- from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
- from overcooked_ai_py.mdp.overcooked_mdp import Action, OvercookedGridworld, OvercookedState
- from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
-
-
- @dataclass
- class Overcooked(MARLEnv[MultiDiscreteSpace]):
-     horizon: int
-     shaping_factor: Schedule
-
-     def __init__(
-         self,
-         oenv: OvercookedEnv,
-         shaping_factor: float | Schedule = 1.0,
-         name_suffix: Optional[str] = None,
-     ):
-         if isinstance(shaping_factor, (int, float)):
-             shaping_factor = Schedule.constant(shaping_factor)
-         self.shaping_factor = shaping_factor
-         self._oenv = oenv
-         assert isinstance(oenv.mdp, OvercookedGridworld)
-         self._mdp = oenv.mdp
-         self._visualizer = StateVisualizer()
-         width, height, layers = tuple(self._mdp.lossless_state_encoding_shape)
-         # -1 because we extract the "urgent" layer to the extras
-         shape = (int(layers - 1), int(width), int(height))
-         super().__init__(
-             n_agents=self._mdp.num_players,
-             action_space=DiscreteSpace(Action.NUM_ACTIONS, labels=[Action.ACTION_TO_CHAR[a] for a in Action.ALL_ACTIONS]).repeat(
-                 self._mdp.num_players
-             ),
-             observation_shape=shape,
-             extras_shape=(2,),
-             extras_meanings=["timestep", "urgent"],
-             state_shape=shape,
-             state_extra_shape=(2,),
-             reward_space=ContinuousSpace.from_shape(1),
-         )
-         self.horizon = int(self._oenv.horizon)
-         if name_suffix is not None:
-             self.name = f"{self.name}-{name_suffix}"
-
-     @property
-     def state(self) -> OvercookedState:
-         """Current state of the environment"""
-         return self._oenv.state
-
-     def set_state(self, state: State):
-         raise NotImplementedError("Not yet implemented")
-
-     @property
-     def time_step(self):
-         return self.state.timestep
-
-     def _state_data(self):
-         players_layers = self._mdp.lossless_state_encoding(self.state)
-         state = np.array(players_layers, dtype=np.float32)
-         # Use axes (agents, channels, height, width) instead of (agents, height, width, channels)
-         state = np.transpose(state, (0, 3, 1, 2))
-         # The last last layer is for "urgency", put it in the extras
-         urgency = float(np.all(state[:, -1]))
-         state = state[:, :-1]
-         return state, urgency
-
-     def get_state(self):
-         data, is_urgent = self._state_data()
-         return State(data[0], np.array([self.time_step / self.horizon, is_urgent], dtype=np.float32))
-
-     def get_observation(self) -> Observation:
-         data, is_urgent = self._state_data()
-         return Observation(
-             data=data,
-             available_actions=self.available_actions(),
-             extras=np.array([[self.time_step / self.horizon, is_urgent]] * self.n_agents, dtype=np.float32),
-         )
-
-     def available_actions(self):
-         available_actions = np.full((self.n_agents, self.n_actions), False)
-         actions = self._mdp.get_actions(self._oenv.state)
-         for agent_num, agent_actions in enumerate(actions):
-             for action in agent_actions:
-                 available_actions[agent_num, Action.ACTION_TO_INDEX[action]] = True
-         return np.array(available_actions, dtype=np.bool)
-
-     def step(self, actions: Sequence[int] | np.ndarray) -> Step:
-         self.shaping_factor.update()
-         actions = [Action.ALL_ACTIONS[a] for a in actions]
-         _, reward, done, info = self._oenv.step(actions, display_phi=True)
-
-         reward += sum(info["shaped_r_by_agent"]) * self.shaping_factor
-         return Step(
-             obs=self.get_observation(),
-             state=self.get_state(),
-             reward=np.array([reward], dtype=np.float32),
-             done=done,
-             truncated=False,
-             info=info,
-         )
-
-     def reset(self):
-         self._oenv.reset()
-         return self.get_observation(), self.get_state()
-
-     def __deepcopy__(self, _):
-         """
-         Note: a specific implementation is needed because `pygame.font.Font` objects are not deep-copiable by default.
-         """
-         mdp = deepcopy(self._mdp)
-         copy = Overcooked(OvercookedEnv.from_mdp(mdp, horizon=self.horizon), deepcopy(self.shaping_factor))
-         copy.name = self.name
-         return copy
-
-     def __getstate__(self):
-         return {"horizon": self.horizon, "mdp": self._mdp, "name": self.name, "schedule": self.shaping_factor}
-
-     def __setstate__(self, state: dict):
-         from overcooked_ai_py.mdp.overcooked_mdp import Recipe
-
-         mdp = state["mdp"]
-         Recipe.configure(mdp.recipe_config)
-         self.__init__(OvercookedEnv.from_mdp(state["mdp"], horizon=state["horizon"]), shaping_factor=state["schedule"])
-         self.name = state["name"]
-
-     def get_image(self):
-         rewards_dict = {}  # dictionary of details you want rendered in the UI
-         for key, value in self._oenv.game_stats.items():
-             if key in [
-                 "cumulative_shaped_rewards_by_agent",
-                 "cumulative_sparse_rewards_by_agent",
-             ]:
-                 rewards_dict[key] = value
-
-         image = self._visualizer.render_state(
-             state=self._oenv.state,
-             grid=self._mdp.terrain_mtx,
-             hud_data=StateVisualizer.default_hud_data(self._oenv.state, **rewards_dict),
-         )
-
-         image = pygame.surfarray.array3d(image)
-         image = np.flip(np.rot90(image, 3), 1)
-         # Depending on the platform, the image may need to be converted to RGB
-         if sys.platform in ("linux", "linux2"):
-             image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-         return image
-
-     @staticmethod
-     def from_layout(
-         layout: Literal[
-             "asymmetric_advantages",
-             "asymmetric_advantages_tomato",
-             "bonus_order_test",
-             "bottleneck",
-             "centre_objects",
-             "centre_pots",
-             "coordination_ring",
-             "corridor",
-             "counter_circuit",
-             "counter_circuit_o_1order",
-             "cramped_corridor",
-             "cramped_room",
-             "cramped_room_o_3orders",
-             "cramped_room_single",
-             "cramped_room_tomato",
-             "five_by_five",
-             "forced_coordination",
-             "forced_coordination_tomato",
-             "inverse_marshmallow_experiment",
-             "large_room",
-             "long_cook_time",
-             "marshmallow_experiment_coordination",
-             "marshmallow_experiment",
-             "mdp_test",
-             "m_shaped_s",
-             "multiplayer_schelling",
-             "pipeline",
-             "scenario1_s",
-             "scenario2",
-             "scenario2_s",
-             "scenario3",
-             "scenario4",
-             "schelling",
-             "schelling_s",
-             "simple_o",
-             "simple_o_t",
-             "simple_tomato",
-             "small_corridor",
-             "soup_coordination",
-             "tutorial_0",
-             "tutorial_1",
-             "tutorial_2",
-             "tutorial_3",
-             "unident",
-             "you_shall_not_pass",
-         ],
-         horizon: int = 400,
-         reward_shaping_factor: float | Schedule = 1.0,
-     ):
-         mdp = OvercookedGridworld.from_layout_name(layout)
-         return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), reward_shaping_factor, layout)
-
-     @staticmethod
-     def from_grid(
-         grid: Sequence[Sequence[Literal["S", "P", "X", "O", "D", "T", "1", "2", " "] | str]],
-         horizon: int = 400,
-         shaping_factor: float | Schedule = 1.0,
-         layout_name: Optional[str] = None,
-     ):
-         """
-         Create an Overcooked environment from a grid layout where
-         - S is a serving location
-         - P is a cooking pot
-         - X is a counter
-         - O is an onion dispenser
-         - D is a dish dispenser
-         - T is a tomato dispenser
-         - 1 is a player 1 starting location
-         - 2 is a player 2 starting location
-         - ' ' is a walkable space
-
-         If provided, `custom_name` is added to the environment name.
-         """
-         # It is necessary to add an explicit layout name because Overcooked saves some files under this
-         # name. By default the name is a concatenation of the grid elements, which may include characters
-         # such as white spaces, pipes ('|') and square brackets ('[' and ']') that are invalid Windows file paths.
-         if layout_name is None:
-             layout_name = "custom-layout"
-         mdp = OvercookedGridworld.from_grid(grid, base_layout_params={"layout_name": layout_name})
-
-         return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), shaping_factor, layout_name)