multi-agent-rlenv 3.3.7__tar.gz → 3.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/PKG-INFO +1 -1
  2. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/__init__.py +1 -1
  3. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/adapters/overcooked_adapter.py +37 -16
  4. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/models/env.py +8 -3
  5. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/models/episode.py +2 -2
  6. multi_agent_rlenv-3.4.0/src/marlenv/utils/__init__.py +10 -0
  7. multi_agent_rlenv-3.4.0/src/marlenv/utils/schedule.py +281 -0
  8. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/tests/test_adapters.py +27 -9
  9. multi_agent_rlenv-3.4.0/tests/test_schedules.py +155 -0
  10. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/tests/test_serialization.py +50 -0
  11. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/.github/workflows/ci.yaml +0 -0
  12. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/.github/workflows/docs.yaml +0 -0
  13. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/.gitignore +0 -0
  14. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/LICENSE +0 -0
  15. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/README.md +0 -0
  16. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/pyproject.toml +0 -0
  17. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/adapters/__init__.py +0 -0
  18. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/adapters/gym_adapter.py +0 -0
  19. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/adapters/pettingzoo_adapter.py +0 -0
  20. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/adapters/pymarl_adapter.py +0 -0
  21. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/adapters/smac_adapter.py +0 -0
  22. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/env_builder.py +0 -0
  23. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/env_pool.py +0 -0
  24. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/exceptions.py +0 -0
  25. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/mock_env.py +0 -0
  26. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/models/__init__.py +0 -0
  27. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/models/observation.py +0 -0
  28. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/models/spaces.py +0 -0
  29. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/models/state.py +0 -0
  30. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/models/step.py +0 -0
  31. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/models/transition.py +0 -0
  32. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/py.typed +0 -0
  33. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/__init__.py +0 -0
  34. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/agent_id_wrapper.py +0 -0
  35. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/available_actions_mask.py +0 -0
  36. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/available_actions_wrapper.py +0 -0
  37. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/blind_wrapper.py +0 -0
  38. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/centralised.py +0 -0
  39. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/delayed_rewards.py +0 -0
  40. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/last_action_wrapper.py +0 -0
  41. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/paddings.py +0 -0
  42. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/penalty_wrapper.py +0 -0
  43. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/rlenv_wrapper.py +0 -0
  44. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/time_limit.py +0 -0
  45. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/src/marlenv/wrappers/video_recorder.py +0 -0
  46. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/tests/__init__.py +0 -0
  47. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/tests/test_episode.py +0 -0
  48. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/tests/test_models.py +0 -0
  49. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/tests/test_pool.py +0 -0
  50. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/tests/test_spaces.py +0 -0
  51. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/tests/test_wrappers.py +0 -0
  52. {multi_agent_rlenv-3.3.7 → multi_agent_rlenv-3.4.0}/tests/utils.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: multi-agent-rlenv
-Version: 3.3.7
+Version: 3.4.0
 Summary: A strongly typed Multi-Agent Reinforcement Learning framework
 Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
 Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
@@ -62,7 +62,7 @@ print(env.extras_shape) # (1, )
 If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
 """
 
-__version__ = "3.3.7"
+__version__ = "3.4.0"
 
 from . import models
 from . import wrappers
@@ -1,14 +1,14 @@
 import sys
 from dataclasses import dataclass
-from typing import Literal, Sequence
+from typing import Literal, Sequence, Optional
 from copy import deepcopy
-from time import time
 
 import cv2
 import numpy as np
 import numpy.typing as npt
 import pygame
 from marlenv.models import ContinuousSpace, DiscreteActionSpace, MARLEnv, Observation, State, Step
+from marlenv.utils import Schedule
 
 from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
 from overcooked_ai_py.mdp.overcooked_mdp import Action, OvercookedGridworld, OvercookedState
@@ -18,10 +18,17 @@ from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
 @dataclass
 class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
     horizon: int
-    reward_shaping: bool
+    shaping_factor: Schedule
 
-    def __init__(self, oenv: OvercookedEnv, reward_shaping: bool = True):
-        self.reward_shaping = reward_shaping
+    def __init__(
+        self,
+        oenv: OvercookedEnv,
+        shaping_factor: float | Schedule = 1.0,
+        name_suffix: Optional[str] = None,
+    ):
+        if isinstance(shaping_factor, (int, float)):
+            shaping_factor = Schedule.constant(shaping_factor)
+        self.shaping_factor = shaping_factor
         self._oenv = oenv
         assert isinstance(oenv.mdp, OvercookedGridworld)
         self._mdp = oenv.mdp
@@ -43,6 +50,8 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
             reward_space=ContinuousSpace.from_shape(1),
         )
         self.horizon = int(self._oenv.horizon)
+        if name_suffix is not None:
+            self.name = f"{self.name}-{name_suffix}"
 
     @property
     def state(self) -> OvercookedState:
@@ -87,10 +96,11 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         return np.array(available_actions, dtype=np.bool)
 
     def step(self, actions: Sequence[int] | npt.NDArray[np.int32 | np.int64]) -> Step:
+        self.shaping_factor.update()
         actions = [Action.ALL_ACTIONS[a] for a in actions]
         _, reward, done, info = self._oenv.step(actions, display_phi=True)
-        if self.reward_shaping:
-            reward += sum(info["shaped_r_by_agent"])
+
+        reward += sum(info["shaped_r_by_agent"]) * self.shaping_factor
         return Step(
             obs=self.get_observation(),
             state=self.get_state(),
@@ -104,19 +114,25 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         self._oenv.reset()
         return self.get_observation(), self.get_state()
 
-    def __deepcopy__(self, memo: dict):
+    def __deepcopy__(self, _):
+        """
+        Note: a specific implementation is needed because `pygame.font.Font` objects are not deep-copiable by default.
+        """
         mdp = deepcopy(self._mdp)
-        return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=self.horizon))
+        copy = Overcooked(OvercookedEnv.from_mdp(mdp, horizon=self.horizon), deepcopy(self.shaping_factor))
+        copy.name = self.name
+        return copy
 
     def __getstate__(self):
-        return {"horizon": self.horizon, "mdp": self._mdp}
+        return {"horizon": self.horizon, "mdp": self._mdp, "name": self.name, "schedule": self.shaping_factor}
 
     def __setstate__(self, state: dict):
         from overcooked_ai_py.mdp.overcooked_mdp import Recipe
 
         mdp = state["mdp"]
         Recipe.configure(mdp.recipe_config)
-        self.__init__(OvercookedEnv.from_mdp(state["mdp"], horizon=state["horizon"]))
+        self.__init__(OvercookedEnv.from_mdp(state["mdp"], horizon=state["horizon"]), shaping_factor=state["schedule"])
+        self.name = state["name"]
 
     def get_image(self):
         rewards_dict = {}  # dictionary of details you want rendered in the UI
@@ -190,16 +206,17 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
             "you_shall_not_pass",
         ],
         horizon: int = 400,
-        reward_shaping: bool = True,
+        reward_shaping_factor: float | Schedule = 1.0,
     ):
         mdp = OvercookedGridworld.from_layout_name(layout)
-        return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon), reward_shaping=reward_shaping)
+        return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), reward_shaping_factor, layout)
 
     @staticmethod
     def from_grid(
         grid: Sequence[Sequence[Literal["S", "P", "X", "O", "D", "T", "1", "2", " "] | str]],
         horizon: int = 400,
-        reward_shaping: bool = True,
+        shaping_factor: float | Schedule = 1.0,
+        layout_name: Optional[str] = None,
    ):
        """
        Create an Overcooked environment from a grid layout where
@@ -212,10 +229,14 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         - 1 is a player 1 starting location
         - 2 is a player 2 starting location
         - ' ' is a walkable space
+
+        If provided, `layout_name` is added to the environment name.
         """
         # It is necessary to add an explicit layout name because Overcooked saves some files under this
         # name. By default the name is a concatenation of the grid elements, which may include characters
         # such as white spaces, pipes ('|') and square brackets ('[' and ']') that are invalid Windows file paths.
-        layout_name = str(time())
+        if layout_name is None:
+            layout_name = "custom-layout"
         mdp = OvercookedGridworld.from_grid(grid, base_layout_params={"layout_name": layout_name})
-        return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon), reward_shaping=reward_shaping)
+
+        return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), shaping_factor, layout_name)
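Taken together, these Overcooked changes replace the boolean `reward_shaping` flag with a shaping factor that can be a plain float or a `Schedule` annealed over time. A minimal usage sketch of the new API, based only on the signatures above (the layout name, step count, and actions are illustrative):

```python
from marlenv.adapters import Overcooked
from marlenv.utils import Schedule

# Anneal the shaped-reward contribution from 1.0 down to 0.0 over 10_000 steps.
env = Overcooked.from_layout(
    "asymmetric_advantages",
    reward_shaping_factor=Schedule.linear(1.0, 0.0, 10_000),
)
obs, state = env.reset()
step = env.step([4, 4])  # STAY for both agents; step() also advances the schedule
```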
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from itertools import product
-from typing import Generic, Optional, Sequence
+from typing import Any, Generic, Optional, Sequence
 
 import cv2
 import numpy as np
@@ -13,8 +13,8 @@ from .spaces import ActionSpace, ContinuousSpace, Space
 from .state import State
 from .step import Step
 
-ActionType = TypeVar("ActionType", default=npt.NDArray)
-ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=ActionSpace)
+ActionType = TypeVar("ActionType", default=Any)
+ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=Any)
 
 
 @dataclass
@@ -108,6 +108,11 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
         """Whether the environment is multi-objective."""
         return self.reward_space.size > 1
 
+    @property
+    def n_objectives(self) -> int:
+        """The number of objectives in the environment."""
+        return self.reward_space.size
+
     def sample_action(self) -> ActionType:
         """Sample an available action from the action space."""
         return self.action_space.sample(self.available_actions())  # type: ignore
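The new `n_objectives` property is a convenience mirror of `reward_space.size`. A quick sketch (assuming `DiscreteMockEnv`, the package's mock environment used in its tests, is constructible with default arguments):

```python
from marlenv import DiscreteMockEnv

env = DiscreteMockEnv()  # assumption: default constructor, single-objective env
assert env.n_objectives == env.reward_space.size
assert env.is_multi_objective == (env.n_objectives > 1)
```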
@@ -179,9 +179,9 @@ class Episode(Generic[A]):
     @cached_property
     def dones(self):
         """The done flags for each transition"""
-        dones = np.zeros_like(self.rewards, dtype=np.float32)
+        dones = np.zeros_like(self.rewards, dtype=np.bool)
         if self.is_done:
-            dones[self.episode_len - 1 :] = 1.0
+            dones[self.episode_len - 1 :] = True
         return dones
 
     @property
@@ -0,0 +1,10 @@
+from .schedule import Schedule, MultiSchedule, RoundedSchedule, LinearSchedule, ExpSchedule
+
+
+__all__ = [
+    "Schedule",
+    "LinearSchedule",
+    "ExpSchedule",
+    "MultiSchedule",
+    "RoundedSchedule",
+]
@@ -0,0 +1,281 @@
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import Callable, Optional, TypeVar
+
+T = TypeVar("T")
+
+
+@dataclass
+class Schedule:
+    """
+    Schedules the value of a variable over time.
+    """
+
+    name: str
+    start_value: float
+    end_value: float
+    _t: int
+    n_steps: int
+
+    def __init__(self, start_value: float, end_value: float, n_steps: int):
+        self.start_value = start_value
+        self.end_value = end_value
+        self.n_steps = n_steps
+        self.name = self.__class__.__name__
+        self._t = 0
+        self._current_value = self.start_value
+
+    def update(self, step: Optional[int] = None):
+        """Update the value of the schedule. Force a step if given."""
+        if step is not None:
+            self._t = step
+        else:
+            self._t += 1
+        if self._t >= self.n_steps:
+            self._current_value = self.end_value
+        else:
+            self._current_value = self._compute()
+
+    @abstractmethod
+    def _compute(self) -> float:
+        """Compute the value of the schedule"""
+
+    @property
+    def value(self) -> float:
+        """Returns the current value of the schedule"""
+        return self._current_value
+
+    @staticmethod
+    def constant(value: float):
+        return ConstantSchedule(value)
+
+    @staticmethod
+    def linear(start_value: float, end_value: float, n_steps: int):
+        return LinearSchedule(start_value, end_value, n_steps)
+
+    @staticmethod
+    def exp(start_value: float, end_value: float, n_steps: int):
+        return ExpSchedule(start_value, end_value, n_steps)
+
+    @staticmethod
+    def arbitrary(func: Callable[[int], float], n_steps: Optional[int] = None):
+        if n_steps is None:
+            n_steps = 0
+        return ArbitrarySchedule(func, n_steps)
+
+    def rounded(self, n_digits: int = 0) -> "RoundedSchedule":
+        return RoundedSchedule(self, n_digits)
+
+    # Operator overloading
+    def __mul__(self, other: T) -> T:
+        return self.value * other  # type: ignore
+
+    def __rmul__(self, other: T) -> T:
+        return self.value * other  # type: ignore
+
+    def __pow__(self, exp: float) -> float:
+        return self.value**exp
+
+    def __rpow__(self, other: T) -> T:
+        return other**self.value  # type: ignore
+
+    def __add__(self, other: T) -> T:
+        return self.value + other  # type: ignore
+
+    def __radd__(self, other: T) -> T:
+        return self.value + other  # type: ignore
+
+    def __neg__(self):
+        return -self.value
+
+    def __pos__(self):
+        return +self.value
+
+    def __sub__(self, other: T) -> T:
+        return self.value - other  # type: ignore
+
+    def __rsub__(self, other: T) -> T:
+        return other - self.value  # type: ignore
+
+    def __div__(self, other: T) -> T:
+        return self.value // other  # type: ignore
+
+    def __rdiv__(self, other: T) -> T:
+        return other // self.value  # type: ignore
+
+    def __truediv__(self, other: T) -> T:
+        return self.value / other  # type: ignore
+
+    def __rtruediv__(self, other: T) -> T:
+        return other / self.value  # type: ignore
+
+    def __lt__(self, other) -> bool:
+        return self.value < other
+
+    def __le__(self, other) -> bool:
+        return self.value <= other
+
+    def __gt__(self, other) -> bool:
+        return self.value > other
+
+    def __ge__(self, other) -> bool:
+        return self.value >= other
+
+    def __eq__(self, other) -> bool:
+        if isinstance(other, Schedule):
+            if self.start_value != other.start_value:
+                return False
+            if self.end_value != other.end_value:
+                return False
+            if self.n_steps != other.n_steps:
+                return False
+            if type(self) is not type(other):
+                return False
+        return self.value == other
+
+    def __ne__(self, other) -> bool:
+        return not (self.__eq__(other))
+
+    def __float__(self):
+        return self.value
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+
+@dataclass(eq=False)
+class LinearSchedule(Schedule):
+    a: float
+    b: float
+
+    def __init__(self, start_value: float, end_value: float, n_steps: int):
+        super().__init__(start_value, end_value, n_steps)
+        self._current_value = self.start_value
+        # y = ax + b
+        self.a = (self.end_value - self.start_value) / self.n_steps
+        self.b = self.start_value
+
+    def _compute(self):
+        return self.a * (self._t) + self.b
+
+    @property
+    def value(self) -> float:
+        return self._current_value
+
+
+@dataclass(eq=False)
+class ExpSchedule(Schedule):
+    """Exponential schedule. After n_steps, the value will be min_value.
+
+    Update formula is next_value = start_value * (min_value / start_value) ** (step / (n - 1))
+    """
+
+    n_steps: int
+    base: float
+    last_update_step: int
+
+    def __init__(self, start_value: float, min_value: float, n_steps: int):
+        super().__init__(start_value, min_value, n_steps)
+        self.base = self.end_value / self.start_value
+        self.last_update_step = self.n_steps - 1
+
+    def _compute(self):
+        return self.start_value * (self.base) ** (self._t / (self.n_steps - 1))
+
+    @property
+    def value(self) -> float:
+        return self._current_value
+
+
+@dataclass(eq=False)
+class ConstantSchedule(Schedule):
+    def __init__(self, value: float):
+        super().__init__(value, value, 0)
+        self._value = value
+
+    def update(self, step=None):
+        return
+
+    @property
+    def value(self) -> float:
+        return self._value
+
+
+@dataclass(eq=False)
+class RoundedSchedule(Schedule):
+    def __init__(self, schedule: Schedule, n_digits: int):
+        super().__init__(schedule.start_value, schedule.end_value, schedule.n_steps)
+        self.schedule = schedule
+        self.n_digits = n_digits
+
+    def update(self, step: int | None = None):
+        return self.schedule.update(step)
+
+    def _compute(self) -> float:
+        return self.schedule._compute()
+
+    @property
+    def value(self) -> float:
+        return round(self.schedule.value, self.n_digits)
+
+
+@dataclass(eq=False)
+class MultiSchedule(Schedule):
+    def __init__(self, schedules: dict[int, Schedule]):
+        ordered_schedules, ordered_steps = MultiSchedule._verify(schedules)
+        n_steps = ordered_steps[-1] + ordered_schedules[-1].n_steps
+        super().__init__(ordered_schedules[0].start_value, ordered_schedules[-1].end_value, n_steps)
+        self.schedules = iter(ordered_schedules)
+        self.current_schedule = next(self.schedules)
+        self.offset = 0
+        self.current_end = ordered_steps[1]
+
+    @staticmethod
+    def _verify(schedules: dict[int, Schedule]):
+        sorted_steps = sorted(schedules.keys())
+        sorted_schedules = [schedules[t] for t in sorted_steps]
+        if sorted_steps[0] != 0:
+            raise ValueError("First schedule must start at t=0")
+        current_step = 0
+        for i in range(len(sorted_steps)):
+            # Artificially set the end step of ConstantSchedules to the next step
+            if isinstance(sorted_schedules[i], ConstantSchedule):
+                if i + 1 < len(sorted_steps):
+                    sorted_schedules[i].n_steps = sorted_steps[i + 1] - sorted_steps[i]
+            if sorted_steps[i] != current_step:
+                raise ValueError(f"Schedules are not contiguous at t={current_step}")
+            current_step += sorted_schedules[i].n_steps
+        return sorted_schedules, sorted_steps
+
+    def update(self, step: int | None = None):
+        if step is not None:
+            raise NotImplementedError("Cannot update MultiSchedule with a specific step")
+        super().update(step)
+        # If we reach the end of the current schedule, update to the next one
+        # except if we are at the end.
+        if self._t == self.current_end and self._t < self.n_steps:
+            self.current_schedule = next(self.schedules)
+            self.offset = self._t
+            self.current_end += self.current_schedule.n_steps
+        self.current_schedule.update(self.relative_step)
+
+    @property
+    def relative_step(self):
+        return self._t - self.offset
+
+    def _compute(self) -> float:
+        return self.current_schedule._compute()
+
+    @property
+    def value(self):
+        return self.current_schedule.value
+
+
+@dataclass(eq=False)
+class ArbitrarySchedule(Schedule):
+    def __init__(self, fn: Callable[[int], float], n_steps: int):
+        super().__init__(fn(0), fn(n_steps), n_steps)
+        self._func = fn
+
+    def _compute(self) -> float:
+        return self._func(self._t)
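The operator overloads above let a `Schedule` stand in for a float wherever it is multiplied, added, or compared, which is exactly how the Overcooked adapter applies it to the shaped reward. A short sketch restricted to the API defined in this file:

```python
from marlenv.utils import Schedule

factor = Schedule.linear(1.0, 0.0, 100)  # anneal from 1.0 to 0.0 over 100 updates
assert 5.0 * factor == 5.0               # behaves like a float via __rmul__
factor.update()                          # advance one step
factor.update(step=50)                   # or jump directly to a given step
assert abs(factor.value - 0.5) < 1e-6
factor.update(step=1_000)                # beyond n_steps the value clamps to end_value
assert factor.value == 0.0
```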
@@ -1,5 +1,3 @@
-from importlib.util import find_spec
-
 import numpy as np
 import pytest
 
@@ -7,12 +5,12 @@ import marlenv
 from marlenv import ContinuousActionSpace, DiscreteActionSpace, DiscreteMockEnv, MARLEnv, Observation, State
 from marlenv.adapters import PymarlAdapter
 
-skip_gym = find_spec("gymnasium") is None
-skip_pettingzoo = find_spec("pettingzoo") is None
-skip_smac = find_spec("smac") is None
+skip_gym = not marlenv.adapters.HAS_GYM
+skip_pettingzoo = not marlenv.adapters.HAS_PETTINGZOO
+skip_smac = not marlenv.adapters.HAS_SMAC
 # Check for "overcooked_ai_py.mdp" specifically because after uninstalling, the package
 # can still be found because of some remaining .pkl file.
-skip_overcooked = find_spec("overcooked_ai_py.mdp") is None
+skip_overcooked = not marlenv.adapters.HAS_OVERCOOKED
 
 
 @pytest.mark.skipif(skip_gym, reason="Gymnasium is not installed")
@@ -188,9 +186,7 @@ def test_overcooked_shaping():
     from marlenv.adapters import Overcooked
 
     UP = 0
-    DOWN = 1
     RIGHT = 2
-    LEFT = 3
     STAY = 4
     INTERACT = 5
     grid = [
@@ -201,7 +197,7 @@ def test_overcooked_shaping():
         ["X", "X", "X", "X", "X"],
     ]
 
-    env = Overcooked.from_grid(grid, reward_shaping=True)
+    env = Overcooked.from_grid(grid)
     env.reset()
     actions_rewards = [
         ([UP, STAY], False),
@@ -216,6 +212,28 @@ def test_overcooked_shaping():
         assert step.reward.item() > 0
 
 
+@pytest.mark.skipif(skip_overcooked, reason="Overcooked is not installed")
+def test_overcooked_name():
+    from marlenv.adapters import Overcooked
+
+    grid = [
+        ["X", "X", "X", "D", "X"],
+        ["X", "O", "S", "2", "X"],
+        ["X", "1", "P", " ", "X"],
+        ["X", "T", "S", " ", "X"],
+        ["X", "X", "X", "X", "X"],
+    ]
+
+    env = Overcooked.from_grid(grid)
+    assert env.name == "Overcooked-custom-layout"
+
+    env = Overcooked.from_grid(grid, layout_name="my incredible grid")
+    assert env.name == "Overcooked-my incredible grid"
+
+    env = Overcooked.from_layout("asymmetric_advantages")
+    assert env.name == "Overcooked-asymmetric_advantages"
+
+
 def test_pymarl():
     LIMIT = 20
     N_AGENTS = 2
@@ -0,0 +1,155 @@
+from marlenv.utils import Schedule, MultiSchedule
+
+
+def is_close(a: float, b: float, tol: float = 1e-6) -> bool:
+    return abs(a - b) < tol
+
+
+def test_linear_schedule_increasing():
+    s = Schedule.linear(0, 1, 10)
+    for i in range(10):
+        assert is_close(s.value, i / 10)
+        s.update()
+    for i in range(10):
+        assert s.value == 1.0
+        s.update()
+
+
+def test_linear_schedule_decreasing():
+    s = Schedule.linear(0, -1, 10)
+    for i in range(10):
+        assert is_close(s.value, -i / 10)
+        s.update()
+    for i in range(10):
+        assert s.value == -1.0
+        s.update()
+
+
+def test_linear_schedule_set_timestep():
+    s = Schedule.linear(0, 1, 10)
+    s.update(50)
+    assert is_close(s.value, 1)
+
+    s.update(0)
+    assert is_close(s.value, 0)
+
+    s.update(5)
+    assert is_close(s.value, 0.5)
+
+
+def test_exp_schedule_increasing():
+    s = Schedule.exp(1, 16, 5)
+    assert is_close(s.value, 1)
+    s.update()
+    assert is_close(s.value, 2)
+    s.update()
+    assert is_close(s.value, 4)
+    s.update()
+    assert is_close(s.value, 8)
+    s.update()
+    assert is_close(s.value, 16)
+    for _ in range(10):
+        s.update()
+        assert is_close(s.value, 16)
+
+
+def test_exp_schedule_set_timestep():
+    s = Schedule.exp(1, 16, 5)
+    s.update(50)
+    assert is_close(s.value, 16)
+
+    s.update(0)
+    assert is_close(s.value, 1)
+
+    s.update(5)
+    assert is_close(s.value, 16)
+
+    s.update(1)
+    assert is_close(s.value, 2)
+
+
+def test_exp_schedule_decreasing():
+    s = Schedule.exp(16, 1, 5)
+    assert is_close(s.value, 16)
+    s.update()
+    assert is_close(s.value, 8)
+    s.update()
+    assert is_close(s.value, 4)
+    s.update()
+    assert is_close(s.value, 2)
+    s.update()
+    assert is_close(s.value, 1)
+    for _ in range(10):
+        s.update()
+        assert is_close(s.value, 1)
+
+
+def test_const_schedule():
+    s = Schedule.constant(50)
+    for _ in range(10):
+        assert s.value == 50
+        s.update()
+
+
+def test_equality_linear():
+    s1 = Schedule.linear(0, 1, 10)
+    s2 = Schedule.linear(0, 1, 10)
+    s3 = Schedule.linear(0, 1, 5)
+    assert s1 == s2
+    assert s1 != s3
+
+    s1.update()
+    assert s1 != s2
+    s2.update()
+    assert s1 == s2
+
+
+def test_equality_exp():
+    s1 = Schedule.exp(1, 16, 5)
+    s2 = Schedule.exp(1, 16, 5)
+    s3 = Schedule.exp(1, 16, 10)
+    assert s1 != 5
+    assert s1 == 1
+    assert s1 == s2
+    assert s1 != s3
+
+    s1.update()
+    assert s1 != s2
+    assert s1 == 2
+    assert s2 == 1
+    s2.update()
+    assert s1 == s2
+    assert s2 == 2
+
+
+def test_multi_schedule():
+    s = MultiSchedule(
+        {
+            0: Schedule.constant(0),
+            10: Schedule.linear(0, 1, 10),
+            20: Schedule.exp(1, 16, 5),
+        }
+    )
+    expected_values = [0.0] * 10 + [i / 10 for i in range(10)] + [2**i for i in range(5)]
+    for i in range(25):
+        assert is_close(s.value, expected_values[i])
+        s.update()
+
+
+def test_equality_const():
+    s1 = Schedule.constant(50)
+    s2 = Schedule.constant(50)
+    assert s1 == s2
+
+    s1.update()
+    assert s1 == s2
+    s2.update()
+    assert s1 == s2
+
+
+def test_inequality_different_schedules():
+    s1 = Schedule.linear(1, 2, 10)
+    s2 = Schedule.exp(1, 2, 10)
+    s3 = Schedule.linear(1, 2, 10)
+    assert s1 != s2
+    assert not s1 != s3
@@ -7,6 +7,7 @@ from copy import deepcopy
 
 import marlenv
 from marlenv import DiscreteMockEnv, wrappers
+from marlenv.utils import Schedule
 
 
 def test_registry():
@@ -199,6 +200,16 @@ def test_deepcopy_overcooked():
     assert env == env2
 
 
+@pytest.mark.skipif(not marlenv.adapters.HAS_OVERCOOKED, reason="Overcooked is not installed")
+def test_deepcopy_overcooked_schedule():
+    env = marlenv.adapters.Overcooked.from_layout("scenario4", reward_shaping_factor=Schedule.linear(1, 0, 10))
+    env2 = deepcopy(env)
+    assert env == env2
+
+    env.random_step()
+    assert not env == env2, "The reward shaping factor should be different"
+
+
 @pytest.mark.skipif(not marlenv.adapters.HAS_OVERCOOKED, reason="Overcooked is not installed")
 def test_pickle_overcooked():
     env = marlenv.adapters.Overcooked.from_layout("scenario1_s", horizon=60)
@@ -281,3 +292,42 @@ def test_json_serialize_pettingzoo():
 def test_json_serialize_smac():
     env = marlenv.adapters.SMAC("3m")
     serde_and_check_key_values(env)
+
+
+class C:
+    def __call__(self, t):
+        return t + 1
+
+
+def test_serialize_schedule():
+    s = Schedule.linear(0, 1, 10)
+    orjson.dumps(s)
+    b = pickle.dumps(s)
+    s2 = pickle.loads(b)
+    assert s == s2
+
+    s = Schedule.exp(1, 16, 5)
+    orjson.dumps(s)
+    b = pickle.dumps(s)
+    s2 = pickle.loads(b)
+    assert s == s2
+
+    s = Schedule.constant(50)
+    orjson.dumps(s)
+    b = pickle.dumps(s)
+    s2 = pickle.loads(b)
+    assert s == s2
+
+    s = Schedule.arbitrary(lambda t: t + 1)
+    b = orjson.dumps(s)
+    try:
+        pickle.dumps(s)
+        assert False, "Should not be able to pickle arbitrary schedules because of the callable lambda"
+    except AttributeError:
+        pass
+
+    s = Schedule.arbitrary(C())
+    orjson.dumps(s)
+    b = pickle.dumps(s)
+    s2 = pickle.loads(b)
+    assert s == s2