multi-agent-rlenv 3.5.4__tar.gz → 3.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/PKG-INFO +23 -6
  2. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/README.md +17 -2
  3. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/pyproject.toml +5 -6
  4. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/__init__.py +12 -3
  5. multi_agent_rlenv-3.6.0/src/marlenv/adapters/__init__.py +33 -0
  6. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/adapters/gym_adapter.py +6 -0
  7. multi_agent_rlenv-3.6.0/src/marlenv/catalog/__init__.py +26 -0
  8. multi_agent_rlenv-3.6.0/src/marlenv/catalog/deepsea.py +73 -0
  9. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/env_builder.py +1 -61
  10. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/models/episode.py +9 -7
  11. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/models/spaces.py +1 -1
  12. multi_agent_rlenv-3.6.0/src/marlenv/utils/__init__.py +15 -0
  13. multi_agent_rlenv-3.6.0/src/marlenv/utils/cached_property_collector.py +17 -0
  14. multi_agent_rlenv-3.6.0/src/marlenv/utils/import_placeholders.py +30 -0
  15. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/utils/schedule.py +16 -1
  16. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/__init__.py +2 -0
  17. multi_agent_rlenv-3.6.0/src/marlenv/wrappers/action_randomizer.py +17 -0
  18. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/tests/test_adapters.py +4 -107
  19. multi_agent_rlenv-3.6.0/tests/test_catalog.py +41 -0
  20. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/tests/test_episode.py +31 -2
  21. multi_agent_rlenv-3.6.0/tests/test_others.py +6 -0
  22. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/tests/test_schedules.py +16 -1
  23. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/tests/test_serialization.py +1 -83
  24. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/tests/test_wrappers.py +1 -1
  25. multi_agent_rlenv-3.5.4/src/marlenv/adapters/__init__.py +0 -42
  26. multi_agent_rlenv-3.5.4/src/marlenv/adapters/overcooked_adapter.py +0 -241
  27. multi_agent_rlenv-3.5.4/src/marlenv/utils/__init__.py +0 -10
  28. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/.github/workflows/ci.yaml +0 -0
  29. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/.github/workflows/docs.yaml +0 -0
  30. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/.gitignore +0 -0
  31. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/LICENSE +0 -0
  32. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/adapters/pettingzoo_adapter.py +0 -0
  33. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/adapters/pymarl_adapter.py +0 -0
  34. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/adapters/smac_adapter.py +0 -0
  35. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/env_pool.py +0 -0
  36. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/exceptions.py +0 -0
  37. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/mock_env.py +0 -0
  38. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/models/__init__.py +0 -0
  39. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/models/env.py +0 -0
  40. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/models/observation.py +0 -0
  41. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/models/state.py +0 -0
  42. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/models/step.py +0 -0
  43. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/models/transition.py +0 -0
  44. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/py.typed +0 -0
  45. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/agent_id_wrapper.py +0 -0
  46. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/available_actions_mask.py +0 -0
  47. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/available_actions_wrapper.py +0 -0
  48. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/blind_wrapper.py +0 -0
  49. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/centralised.py +0 -0
  50. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/delayed_rewards.py +0 -0
  51. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/last_action_wrapper.py +0 -0
  52. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/paddings.py +0 -0
  53. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/penalty_wrapper.py +0 -0
  54. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/potential_shaping.py +0 -0
  55. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/rlenv_wrapper.py +0 -0
  56. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/time_limit.py +0 -0
  57. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/src/marlenv/wrappers/video_recorder.py +0 -0
  58. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/tests/__init__.py +0 -0
  59. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/tests/test_models.py +0 -0
  60. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/tests/test_pool.py +0 -0
  61. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/tests/test_spaces.py +0 -0
  62. {multi_agent_rlenv-3.5.4 → multi_agent_rlenv-3.6.0}/tests/utils.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: multi-agent-rlenv
- Version: 3.5.4
+ Version: 3.6.0
  Summary: A strongly typed Multi-Agent Reinforcement Learning framework
  Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
  Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
@@ -13,7 +13,8 @@ Requires-Dist: opencv-python>=4.0
  Requires-Dist: typing-extensions>=4.0
  Provides-Extra: all
  Requires-Dist: gymnasium>0.29.1; extra == 'all'
- Requires-Dist: overcooked-ai; extra == 'all'
+ Requires-Dist: laser-learning-environment>=2.6.1; extra == 'all'
+ Requires-Dist: overcooked>=0.1.0; extra == 'all'
  Requires-Dist: pettingzoo>=1.20; extra == 'all'
  Requires-Dist: pymunk>=6.0; extra == 'all'
  Requires-Dist: pysc2; extra == 'all'
@@ -22,9 +23,10 @@ Requires-Dist: smac; extra == 'all'
  Requires-Dist: torch>=2.0; extra == 'all'
  Provides-Extra: gym
  Requires-Dist: gymnasium>=0.29.1; extra == 'gym'
+ Provides-Extra: lle
+ Requires-Dist: laser-learning-environment>=2.6.1; extra == 'lle'
  Provides-Extra: overcooked
- Requires-Dist: overcooked-ai>=1.1.0; extra == 'overcooked'
- Requires-Dist: scipy>=1.10; extra == 'overcooked'
+ Requires-Dist: overcooked>=0.1.0; extra == 'overcooked'
  Provides-Extra: pettingzoo
  Requires-Dist: pettingzoo>=1.20; extra == 'pettingzoo'
  Requires-Dist: pymunk>=6.0; extra == 'pettingzoo'
@@ -62,9 +64,24 @@ $ pip install marlenv[smac] # Install SMAC
  $ pip install marlenv[gym,smac] # Install Gym & smac support
  ```

+ ## Using the `marlenv` environment catalog
+ Some environments are registered in `marlenv` and can be instantiated directly from its catalog.
+
+ ```python
+ from marlenv import catalog
+
+ env1 = catalog.Overcooked.from_layout("scenario4")
+ env2 = catalog.LLE.level(6)
+ env3 = catalog.DeepSea(max_depth=5)
+ ```
+ Note that using the catalog requires the corresponding environment package to be installed. For instance, you need the `laser-learning-environment` package to use `catalog.LLE`, which can be installed via the corresponding extra as shown below.
+ ```bash
+ pip install multi-agent-rlenv[lle]
+ ```
+

  ## Using `marlenv` with existing libraries
- `marlenv` unifies multiple popular libraries under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.
+ `marlenv` provides adapters for the most popular libraries to unify them under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.

  ```python
  import marlenv
@@ -74,7 +91,7 @@ gym_env = marlenv.make("CartPole-v1", seed=25)

  # You can seamlessly instantiate a SMAC environment and directly pass your required arguments
  from marlenv.adapters import SMAC
- smac_env = env2 = SMAC("3m", debug=True, difficulty="9")
+ smac_env = SMAC("3m", debug=True, difficulty="9")

  # pettingzoo is also supported
  from pettingzoo.sisl import pursuit_v4
@@ -24,9 +24,24 @@ $ pip install marlenv[smac] # Install SMAC
  $ pip install marlenv[gym,smac] # Install Gym & smac support
  ```

+ ## Using the `marlenv` environment catalog
+ Some environments are registered in `marlenv` and can be instantiated directly from its catalog.
+
+ ```python
+ from marlenv import catalog
+
+ env1 = catalog.Overcooked.from_layout("scenario4")
+ env2 = catalog.LLE.level(6)
+ env3 = catalog.DeepSea(max_depth=5)
+ ```
+ Note that using the catalog requires the corresponding environment package to be installed. For instance, you need the `laser-learning-environment` package to use `catalog.LLE`, which can be installed via the corresponding extra as shown below.
+ ```bash
+ pip install multi-agent-rlenv[lle]
+ ```
+

  ## Using `marlenv` with existing libraries
- `marlenv` unifies multiple popular libraries under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.
+ `marlenv` provides adapters for the most popular libraries to unify them under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.

  ```python
  import marlenv
@@ -36,7 +51,7 @@ gym_env = marlenv.make("CartPole-v1", seed=25)

  # You can seamlessly instantiate a SMAC environment and directly pass your required arguments
  from marlenv.adapters import SMAC
- smac_env = env2 = SMAC("3m", debug=True, difficulty="9")
+ smac_env = SMAC("3m", debug=True, difficulty="9")

  # pettingzoo is also supported
  from pettingzoo.sisl import pursuit_v4
@@ -1,6 +1,6 @@
  [project]
  name = "multi-agent-rlenv"
- dynamic = ["version"]
+ version = "3.6.0"
  description = "A strongly typed Multi-Agent Reinforcement Learning framework"
  authors = [
      { "name" = "Yannick Molinghen", "email" = "yannick.molinghen@ulb.be" },
@@ -19,14 +19,16 @@ dependencies = ["numpy>=2.0.0", "opencv-python>=4.0", "typing_extensions>=4.0"]
  gym = ["gymnasium>=0.29.1"]
  smac = ["smac", "pysc2"]
  pettingzoo = ["pettingzoo>=1.20", "pymunk>=6.0", "scipy>=1.10"]
- overcooked = ["overcooked-ai>=1.1.0", "scipy>=1.10"]
+ overcooked = ["overcooked>=0.1.0"]
+ lle = ["laser-learning-environment>=2.6.1"]
  torch = ["torch>=2.0"]
  all = [
      "gymnasium>0.29.1",
      "pettingzoo>=1.20",
-     "overcooked-ai",
+     "overcooked>=0.1.0",
      "smac",
      "pysc2",
+     "laser-learning-environment>=2.6.1",
      "pymunk>=6.0",
      "scipy>=1.10",
      "torch>=2.0",
@@ -42,9 +44,6 @@ line-length = 140
  [tool.hatch.build.targets.wheel]
  packages = ["src/marlenv"]

- [tool.hatch]
- version = { "path" = "src/marlenv/__init__.py", "attr" = "__version__" }
-

  [tool.pytest.ini_options]
  testpaths = ["tests"]
@@ -62,7 +62,13 @@ print(env.extras_shape) # (1, )
  If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
  """

- __version__ = "3.5.4"
+ from importlib.metadata import version, PackageNotFoundError
+
+ try:
+     __version__ = version("multi-agent-rlenv")
+ except PackageNotFoundError:
+     __version__ = "0.0.0"  # convenient fallback for dev/CI
+

  from . import models
  from .models import (
@@ -82,16 +88,19 @@ from .models import (

  from . import wrappers
  from . import adapters
- from .env_builder import make, Builder
+ from .env_builder import Builder
  from .wrappers import RLEnvWrapper
  from .mock_env import DiscreteMockEnv, DiscreteMOMockEnv
+ from . import catalog
+ from .adapters import make

  __all__ = [
      "models",
+     "make",
+     "catalog",
      "wrappers",
      "adapters",
      "spaces",
-     "make",
      "Builder",
      "MARLEnv",
      "Step",
@@ -0,0 +1,33 @@
+ from importlib.util import find_spec
+ from .pymarl_adapter import PymarlAdapter
+ from marlenv.utils import DummyClass, dummy_function
+
+ HAS_GYM = find_spec("gymnasium") is not None
+ if HAS_GYM:
+     from .gym_adapter import Gym, make
+ else:
+     Gym = DummyClass("gymnasium")
+     make = dummy_function("gymnasium")
+
+ HAS_PETTINGZOO = find_spec("pettingzoo") is not None
+ if HAS_PETTINGZOO:
+     from .pettingzoo_adapter import PettingZoo
+ else:
+     PettingZoo = DummyClass("pettingzoo")
+
+ HAS_SMAC = find_spec("smac") is not None
+ if HAS_SMAC:
+     from .smac_adapter import SMAC
+ else:
+     SMAC = DummyClass("smac", "https://github.com/oxwhirl/smac.git")
+
+ __all__ = [
+     "PymarlAdapter",
+     "Gym",
+     "make",
+     "PettingZoo",
+     "SMAC",
+     "HAS_GYM",
+     "HAS_PETTINGZOO",
+     "HAS_SMAC",
+ ]
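For context, a minimal sketch of how a caller can guard on the feature flags exported above before touching an optional adapter (illustrative only, not part of the package diff):

```python
from marlenv import adapters

# The HAS_* flags come from the new adapters/__init__.py above; when gymnasium
# is missing, `Gym` and `make` are placeholders that raise ImportError on use.
if adapters.HAS_GYM:
    env = adapters.make("CartPole-v1")  # wraps gymnasium.make with render_mode="rgb_array"
else:
    print("gymnasium is not installed, so adapters.Gym is only a placeholder")
```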
@@ -78,3 +78,9 @@ class Gym(MARLEnv[Space]):

      def seed(self, seed_value: int):
          self._gym_env.reset(seed=seed_value)
+
+
+ def make(env_id: str, **kwargs):
+     """Make a `Gym` RLEnv from a gymnasium environment id (e.g. "CartPole-v1")."""
+     gym_env = gym.make(env_id, render_mode="rgb_array", **kwargs)
+     return Gym(gym_env)
@@ -0,0 +1,26 @@
+ from importlib.util import find_spec
+ from ..utils.import_placeholders import DummyClass
+ from marlenv.adapters import SMAC
+ from .deepsea import DeepSea
+
+
+ HAS_LLE = find_spec("lle") is not None
+ if HAS_LLE:
+     from lle import LLE  # pyright: ignore[reportMissingImports]
+ else:
+     LLE = DummyClass("lle", "laser-learning-environment")
+
+ HAS_OVERCOOKED = find_spec("overcooked") is not None
+ if HAS_OVERCOOKED:
+     from overcooked import Overcooked  # pyright: ignore[reportMissingImports]
+ else:
+     Overcooked = DummyClass("overcooked", "overcooked")
+
+ __all__ = [
+     "Overcooked",
+     "SMAC",
+     "LLE",
+     "DeepSea",
+     "HAS_LLE",
+     "HAS_OVERCOOKED",
+ ]
@@ -0,0 +1,73 @@
+ from typing import Sequence
+ import numpy as np
+ from marlenv import MARLEnv, MultiDiscreteSpace, DiscreteSpace, Observation, State, Step
+ from dataclasses import dataclass
+
+
+ LEFT = 0
+ RIGHT = 1
+
+
+ @dataclass
+ class DeepSea(MARLEnv[MultiDiscreteSpace]):
+     """
+     Deep Sea single-agent environment to test for deep exploration. The probability of reaching the goal state under random exploration is 2^(-max_depth).
+
+     The agent explores a 2D grid where the bottom-right corner (max_depth, max_depth) is the goal and is the only state to yield a reward.
+     The agent starts in the top-left corner (0, 0).
+     The agent has two actions: left or right, and taking an action makes the agent dive one row deeper. The agent can not go beyond the grid boundaries.
+     Going right gives a penalty of (0.01 / max_depth).
+     """
+
+     max_depth: int
+
+     def __init__(self, max_depth: int):
+         super().__init__(
+             n_agents=1,
+             action_space=DiscreteSpace(size=2, labels=["left", "right"]).repeat(1),
+             observation_shape=(2,),
+             state_shape=(2,),
+         )
+         self.max_depth = max_depth
+         self._row = 0
+         self._col = 0
+         self._step_right_penalty = -0.01 / self.max_depth
+
+     def get_observation(self) -> Observation:
+         return Observation(np.array([self._row, self._col], dtype=np.float32), self.available_actions())
+
+     def get_state(self) -> State:
+         return State(np.array([self._row, self._col], dtype=np.float32))
+
+     def reset(self):
+         self._row = 0
+         self._col = 0
+         return self.get_observation(), self.get_state()
+
+     def step(self, action: Sequence[int]):
+         self._row += 1
+         if action[0] == LEFT:
+             self._col -= 1
+         else:
+             self._col += 1
+         self._col = max(0, self._col)
+         if action[0] == RIGHT:
+             if self._row == self.max_depth:
+                 reward = 1.0
+             else:
+                 reward = self._step_right_penalty
+         else:
+             reward = 0.0
+         return Step(
+             self.get_observation(),
+             self.get_state(),
+             reward,
+             done=self._row == self.max_depth,
+         )
+
+     def set_state(self, state: State):
+         self._row, self._col = state.data
+
+     @property
+     def agent_state_size(self):
+         return 2
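To make the DeepSea dynamics concrete, a small illustrative rollout follows (not part of the diff; it only uses the constructor, `reset` and `step` shown above):

```python
from marlenv.catalog import DeepSea

env = DeepSea(max_depth=5)
obs, state = env.reset()
for _ in range(5):        # max_depth steps bring the agent to the bottom row
    step = env.step([1])  # 1 == RIGHT, per the module constants above
# The first 4 RIGHT moves each cost 0.01 / 5; the 5th reaches row == max_depth
# and yields the +1 reward, after which the episode is done.
```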
@@ -1,73 +1,13 @@
  from dataclasses import dataclass
- from typing import Generic, Literal, Optional, TypeVar, overload
+ from typing import Generic, Literal, Optional, TypeVar
  import numpy as np
  import numpy.typing as npt

  from . import wrappers
- from marlenv import adapters
  from .models import Space, MARLEnv

  AS = TypeVar("AS", bound=Space)

- if adapters.HAS_PETTINGZOO:
-     from .adapters import PettingZoo
-     from pettingzoo import ParallelEnv
-
-     @overload
-     def make(env: ParallelEnv) -> PettingZoo: ...
-
-
- if adapters.HAS_GYM:
-     from .adapters import Gym
-     from gymnasium import Env
-     import gymnasium
-
-     @overload
-     def make(env: Env) -> Gym: ...
-
-     @overload
-     def make(env: str, **kwargs) -> Gym:
-         """
-         Make an RLEnv from the `gymnasium` registry (e.g: "CartPole-v1").
-         """
-
-
- if adapters.HAS_SMAC:
-     from .adapters import SMAC
-     from smac.env import StarCraft2Env
-
-     @overload
-     def make(env: StarCraft2Env) -> SMAC: ...
-
-
- if adapters.HAS_OVERCOOKED:
-     from .adapters import Overcooked
-     from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
-
-     @overload
-     def make(env: OvercookedEnv) -> Overcooked: ...
-
-
- def make(env, **kwargs):
-     """Make an RLEnv from str (Gym) or PettingZoo"""
-     match env:
-         case MARLEnv():
-             return env
-         case str(env_id):
-             if adapters.HAS_GYM:
-                 gym_env = gymnasium.make(env_id, render_mode="rgb_array", **kwargs)
-                 return Gym(gym_env)
-
-     if adapters.HAS_PETTINGZOO and isinstance(env, ParallelEnv):
-         return PettingZoo(env)  # type: ignore
-     if adapters.HAS_SMAC and isinstance(env, StarCraft2Env):
-         return SMAC(env)
-     if adapters.HAS_OVERCOOKED and isinstance(env, OvercookedEnv):
-         return Overcooked(env)  # type: ignore
-     if adapters.HAS_GYM and isinstance(env, Env):
-         return Gym(env)
-     raise ValueError(f"Unknown environment type: {type(env)}")
-

  @dataclass
  class Builder(Generic[AS]):
@@ -2,20 +2,22 @@ from dataclasses import dataclass
  from functools import cached_property
  from typing import Any, Callable, Optional, Sequence, overload

+ import cv2
  import numpy as np
  import numpy.typing as npt
- import cv2

+ from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException
+ from marlenv.utils import CachedPropertyInvalidator
+
+ from .env import MARLEnv
  from .observation import Observation
  from .state import State
  from .step import Step
  from .transition import Transition
- from .env import MARLEnv
- from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException


  @dataclass
- class Episode:
+ class Episode(CachedPropertyInvalidator):
      """Episode model made of observations, actions, rewards, ..."""

      all_observations: list[npt.NDArray[np.float32]]
@@ -153,12 +155,12 @@
          """Get the next extra features"""
          return self.all_extras[1:]

-     @cached_property
+     @property
      def n_agents(self):
          """The number of agents in the episode"""
          return self.all_extras[0].shape[0]

-     @cached_property
+     @property
      def n_actions(self):
          """The number of actions"""
          return len(self.all_available_actions[0][0])
@@ -267,7 +269,7 @@
      def __len__(self):
          return self.episode_len

-     @cached_property
+     @property
      def score(self) -> list[float]:
          """The episode score (sum of all rewards across all objectives)"""
          score = []
@@ -105,7 +105,7 @@ class MultiDiscreteSpace(Space):
      def sample(self, mask: Optional[npt.NDArray[np.bool] | list[npt.NDArray[np.bool]]] = None):
          if mask is None:
              return np.array([space.sample() for space in self.spaces], dtype=np.int32)
-         return np.array([space.sample(mask=mask) for mask, space in zip(mask, self.spaces)], dtype=np.int32)
+         return np.array([space.sample(mask=mask) for mask, space in zip(mask, self.spaces)], dtype=np.int32)  # type: ignore

      def __eq__(self, value: object) -> bool:
          if not isinstance(value, MultiDiscreteSpace):
@@ -0,0 +1,15 @@
+ from .cached_property_collector import CachedPropertyCollector, CachedPropertyInvalidator
+ from .schedule import ExpSchedule, LinearSchedule, MultiSchedule, RoundedSchedule, Schedule
+ from .import_placeholders import DummyClass, dummy_function
+
+ __all__ = [
+     "Schedule",
+     "LinearSchedule",
+     "ExpSchedule",
+     "MultiSchedule",
+     "RoundedSchedule",
+     "CachedPropertyCollector",
+     "CachedPropertyInvalidator",
+     "DummyClass",
+     "dummy_function",
+ ]
@@ -0,0 +1,17 @@
+ from functools import cached_property
+
+
+ class CachedPropertyCollector(type):
+     def __init__(cls, name: str, bases: tuple, namespace: dict):
+         super().__init__(name, bases, namespace)
+         cls.CACHED_PROPERTY_NAMES = [key for key, value in namespace.items() if isinstance(value, cached_property)]
+
+
+ class CachedPropertyInvalidator(metaclass=CachedPropertyCollector):
+     def __init__(self):
+         super().__init__()
+
+     def invalidate_cached_properties(self):
+         for key in self.__class__.CACHED_PROPERTY_NAMES:
+             if hasattr(self, key):
+                 delattr(self, key)
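As an illustration of how `CachedPropertyInvalidator` is meant to be used (for instance by `Episode` above), here is a minimal sketch with a hypothetical subclass:

```python
from functools import cached_property
from marlenv.utils import CachedPropertyInvalidator


class Stats(CachedPropertyInvalidator):
    """Hypothetical subclass: the metaclass records its cached_property names."""

    def __init__(self, values: list[float]):
        super().__init__()
        self.values = values

    @cached_property
    def total(self) -> float:
        return sum(self.values)


stats = Stats([1.0, 2.0])
print(stats.total)                    # 3.0, cached on the instance
stats.values.append(4.0)
stats.invalidate_cached_properties()  # drops the cached value
print(stats.total)                    # recomputed: 7.0
```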
@@ -0,0 +1,30 @@
+ from typing import Optional, Any
+
+
+ class DummyClass:
+     def __init__(self, module_name: str, package_name: Optional[str] = None):
+         self.module_name = module_name
+         if package_name is None:
+             self.package_name = module_name
+         else:
+             self.package_name = package_name
+
+     def _raise_error(self):
+         raise ImportError(
+             f"The optional dependency `{self.module_name}` is not installed.\nInstall the `{self.package_name}` package (e.g. pip install {self.package_name})."
+         )
+
+     def __getattr__(self, _):
+         self._raise_error()
+
+     def __call__(self, *args, **kwargs):
+         self._raise_error()
+
+
+ def dummy_function(module_name: str, package_name: Optional[str] = None):
+     dummy = DummyClass(module_name, package_name)
+
+     def fail(*args, **kwargs) -> Any:
+         dummy()
+
+     return fail
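A short sketch of the failure mode these placeholders produce when an optional dependency is missing (illustrative, mirroring how adapters/__init__.py uses them):

```python
from marlenv.utils import DummyClass, dummy_function

Gym = DummyClass("gymnasium")       # stands in for the real adapter class
make = dummy_function("gymnasium")  # stands in for the real factory function

try:
    make("CartPole-v1")  # calling the placeholder raises a helpful ImportError
except ImportError as err:
    print(err)  # The optional dependency `gymnasium` is not installed. ...
```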
@@ -1,6 +1,6 @@
  from abc import abstractmethod
  from dataclasses import dataclass
- from typing import Callable, Optional, TypeVar
+ from typing import Any, Callable, Optional, TypeVar

  T = TypeVar("T")

@@ -142,6 +142,21 @@ class Schedule:
      def __int__(self) -> int:
          return int(self.value)

+     @staticmethod
+     def from_json(data: dict[str, Any]):
+         """Create a Schedule from a JSON-like dictionary."""
+         classname = data.get("name")
+         if classname == "LinearSchedule":
+             return LinearSchedule(data["start_value"], data["end_value"], data["n_steps"])
+         elif classname == "ExpSchedule":
+             return ExpSchedule(data["start_value"], data["end_value"], data["n_steps"])
+         elif classname == "ConstantSchedule":
+             return ConstantSchedule(data["value"])
+         elif classname == "ArbitrarySchedule":
+             raise NotImplementedError("ArbitrarySchedule cannot be deserialized from JSON")
+         else:
+             raise ValueError(f"Unknown schedule type: {classname}")
+

  @dataclass(eq=False)
  class LinearSchedule(Schedule):
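A minimal sketch of the dictionary format `Schedule.from_json` reads, based on the keys used above (illustrative; the exact serialisation format is assumed):

```python
from marlenv.utils import Schedule

# Keys match what from_json reads for a LinearSchedule.
data = {"name": "LinearSchedule", "start_value": 1.0, "end_value": 0.05, "n_steps": 10_000}
epsilon = Schedule.from_json(data)
print(epsilon.value)  # presumably 1.0 before any update, i.e. the start value
```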
@@ -11,6 +11,7 @@ from .centralised import Centralized
  from .available_actions_mask import AvailableActionsMask
  from .delayed_rewards import DelayedReward
  from .potential_shaping import PotentialShaping
+ from .action_randomizer import ActionRandomizer

  __all__ = [
      "RLEnvWrapper",
@@ -28,4 +29,5 @@ __all__ = [
      "Centralized",
      "DelayedReward",
      "PotentialShaping",
+     "ActionRandomizer",
  ]
@@ -0,0 +1,17 @@
+ from .rlenv_wrapper import RLEnvWrapper, AS, MARLEnv
+ import numpy as np
+
+
+ class ActionRandomizer(RLEnvWrapper[AS]):
+     def __init__(self, env: MARLEnv[AS], p: float):
+         super().__init__(env)
+         self.p = p
+
+     def step(self, action):
+         if np.random.rand() < self.p:
+             action = self.action_space.sample()
+         return super().step(action)
+
+     def seed(self, seed_value: int):
+         np.random.seed(seed_value)
+         super().seed(seed_value)
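Finally, an illustrative sketch of the new `ActionRandomizer` wrapper (not part of the diff; it assumes `DiscreteMockEnv` can be built with its defaults):

```python
from marlenv import DiscreteMockEnv
from marlenv.wrappers import ActionRandomizer

# With probability p, the wrapper replaces the chosen joint action with a
# random sample from the wrapped environment's action space.
env = ActionRandomizer(DiscreteMockEnv(), p=0.1)
env.seed(42)  # seeds numpy's global RNG, then the wrapped environment
```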