multi-agent-rlenv 3.3.7__py3-none-any.whl → 3.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- marlenv/__init__.py +1 -1
- marlenv/adapters/overcooked_adapter.py +37 -16
- marlenv/models/env.py +8 -3
- marlenv/models/episode.py +2 -2
- marlenv/utils/__init__.py +10 -0
- marlenv/utils/schedule.py +281 -0
- {multi_agent_rlenv-3.3.7.dist-info → multi_agent_rlenv-3.4.0.dist-info}/METADATA +1 -1
- {multi_agent_rlenv-3.3.7.dist-info → multi_agent_rlenv-3.4.0.dist-info}/RECORD +10 -8
- {multi_agent_rlenv-3.3.7.dist-info → multi_agent_rlenv-3.4.0.dist-info}/WHEEL +0 -0
- {multi_agent_rlenv-3.3.7.dist-info → multi_agent_rlenv-3.4.0.dist-info}/licenses/LICENSE +0 -0
marlenv/__init__.py
CHANGED
|
@@ -62,7 +62,7 @@ print(env.extras_shape) # (1, )
|
|
|
62
62
|
If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
|
|
63
63
|
"""
|
|
64
64
|
|
|
65
|
-
__version__ = "3.3.7"
|
|
65
|
+
__version__ = "3.4.0"
|
|
66
66
|
|
|
67
67
|
from . import models
|
|
68
68
|
from . import wrappers
|
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
import sys
|
|
2
2
|
from dataclasses import dataclass
|
|
3
|
-
from typing import Literal, Sequence
|
|
3
|
+
from typing import Literal, Sequence, Optional
|
|
4
4
|
from copy import deepcopy
|
|
5
|
-
from time import time
|
|
6
5
|
|
|
7
6
|
import cv2
|
|
8
7
|
import numpy as np
|
|
9
8
|
import numpy.typing as npt
|
|
10
9
|
import pygame
|
|
11
10
|
from marlenv.models import ContinuousSpace, DiscreteActionSpace, MARLEnv, Observation, State, Step
|
|
11
|
+
from marlenv.utils import Schedule
|
|
12
12
|
|
|
13
13
|
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
|
|
14
14
|
from overcooked_ai_py.mdp.overcooked_mdp import Action, OvercookedGridworld, OvercookedState
|
|
@@ -18,10 +18,17 @@ from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
|
|
|
18
18
|
@dataclass
|
|
19
19
|
class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
|
|
20
20
|
horizon: int
|
|
21
|
-
|
|
21
|
+
shaping_factor: Schedule
|
|
22
22
|
|
|
23
|
-
def __init__(
|
|
24
|
-
self
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
oenv: OvercookedEnv,
|
|
26
|
+
shaping_factor: float | Schedule = 1.0,
|
|
27
|
+
name_suffix: Optional[str] = None,
|
|
28
|
+
):
|
|
29
|
+
if isinstance(shaping_factor, (int, float)):
|
|
30
|
+
shaping_factor = Schedule.constant(shaping_factor)
|
|
31
|
+
self.shaping_factor = shaping_factor
|
|
25
32
|
self._oenv = oenv
|
|
26
33
|
assert isinstance(oenv.mdp, OvercookedGridworld)
|
|
27
34
|
self._mdp = oenv.mdp
|
|
@@ -43,6 +50,8 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
|
|
|
43
50
|
reward_space=ContinuousSpace.from_shape(1),
|
|
44
51
|
)
|
|
45
52
|
self.horizon = int(self._oenv.horizon)
|
|
53
|
+
if name_suffix is not None:
|
|
54
|
+
self.name = f"{self.name}-{name_suffix}"
|
|
46
55
|
|
|
47
56
|
@property
|
|
48
57
|
def state(self) -> OvercookedState:
|
|
@@ -87,10 +96,11 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
|
|
|
87
96
|
return np.array(available_actions, dtype=np.bool)
|
|
88
97
|
|
|
89
98
|
def step(self, actions: Sequence[int] | npt.NDArray[np.int32 | np.int64]) -> Step:
|
|
99
|
+
self.shaping_factor.update()
|
|
90
100
|
actions = [Action.ALL_ACTIONS[a] for a in actions]
|
|
91
101
|
_, reward, done, info = self._oenv.step(actions, display_phi=True)
|
|
92
|
-
|
|
93
|
-
|
|
102
|
+
|
|
103
|
+
reward += sum(info["shaped_r_by_agent"]) * self.shaping_factor
|
|
94
104
|
return Step(
|
|
95
105
|
obs=self.get_observation(),
|
|
96
106
|
state=self.get_state(),
|
|
@@ -104,19 +114,25 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
|
|
|
104
114
|
self._oenv.reset()
|
|
105
115
|
return self.get_observation(), self.get_state()
|
|
106
116
|
|
|
107
|
-
def __deepcopy__(self,
|
|
117
|
+
def __deepcopy__(self, _):
|
|
118
|
+
"""
|
|
119
|
+
Note: a specific implementation is needed because `pygame.font.Font` objects are not deep-copiable by default.
|
|
120
|
+
"""
|
|
108
121
|
mdp = deepcopy(self._mdp)
|
|
109
|
-
|
|
122
|
+
copy = Overcooked(OvercookedEnv.from_mdp(mdp, horizon=self.horizon), deepcopy(self.shaping_factor))
|
|
123
|
+
copy.name = self.name
|
|
124
|
+
return copy
|
|
110
125
|
|
|
111
126
|
def __getstate__(self):
|
|
112
|
-
return {"horizon": self.horizon, "mdp": self._mdp}
|
|
127
|
+
return {"horizon": self.horizon, "mdp": self._mdp, "name": self.name, "schedule": self.shaping_factor}
|
|
113
128
|
|
|
114
129
|
def __setstate__(self, state: dict):
|
|
115
130
|
from overcooked_ai_py.mdp.overcooked_mdp import Recipe
|
|
116
131
|
|
|
117
132
|
mdp = state["mdp"]
|
|
118
133
|
Recipe.configure(mdp.recipe_config)
|
|
119
|
-
self.__init__(OvercookedEnv.from_mdp(state["mdp"], horizon=state["horizon"]))
|
|
134
|
+
self.__init__(OvercookedEnv.from_mdp(state["mdp"], horizon=state["horizon"]), shaping_factor=state["schedule"])
|
|
135
|
+
self.name = state["name"]
|
|
120
136
|
|
|
121
137
|
def get_image(self):
|
|
122
138
|
rewards_dict = {} # dictionary of details you want rendered in the UI
|
|
@@ -190,16 +206,17 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
|
|
|
190
206
|
"you_shall_not_pass",
|
|
191
207
|
],
|
|
192
208
|
horizon: int = 400,
|
|
193
|
-
|
|
209
|
+
reward_shaping_factor: float | Schedule = 1.0,
|
|
194
210
|
):
|
|
195
211
|
mdp = OvercookedGridworld.from_layout_name(layout)
|
|
196
|
-
return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon
|
|
212
|
+
return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), reward_shaping_factor, layout)
|
|
197
213
|
|
|
198
214
|
@staticmethod
|
|
199
215
|
def from_grid(
|
|
200
216
|
grid: Sequence[Sequence[Literal["S", "P", "X", "O", "D", "T", "1", "2", " "] | str]],
|
|
201
217
|
horizon: int = 400,
|
|
202
|
-
|
|
218
|
+
shaping_factor: float | Schedule = 1.0,
|
|
219
|
+
layout_name: Optional[str] = None,
|
|
203
220
|
):
|
|
204
221
|
"""
|
|
205
222
|
Create an Overcooked environment from a grid layout where
|
|
@@ -212,10 +229,14 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
|
|
|
212
229
|
- 1 is a player 1 starting location
|
|
213
230
|
- 2 is a player 2 starting location
|
|
214
231
|
- ' ' is a walkable space
|
|
232
|
+
|
|
233
|
+
If provided, `custom_name` is added to the environment name.
|
|
215
234
|
"""
|
|
216
235
|
# It is necessary to add an explicit layout name because Overcooked saves some files under this
|
|
217
236
|
# name. By default the name is a concatenation of the grid elements, which may include characters
|
|
218
237
|
# such as white spaces, pipes ('|') and square brackets ('[' and ']') that are invalid Windows file paths.
|
|
219
|
-
layout_name
|
|
238
|
+
if layout_name is None:
|
|
239
|
+
layout_name = "custom-layout"
|
|
220
240
|
mdp = OvercookedGridworld.from_grid(grid, base_layout_params={"layout_name": layout_name})
|
|
221
|
-
|
|
241
|
+
|
|
242
|
+
return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), shaping_factor, layout_name)
|
marlenv/models/env.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from itertools import product
|
|
4
|
-
from typing import Generic, Optional, Sequence
|
|
4
|
+
from typing import Any, Generic, Optional, Sequence
|
|
5
5
|
|
|
6
6
|
import cv2
|
|
7
7
|
import numpy as np
|
|
@@ -13,8 +13,8 @@ from .spaces import ActionSpace, ContinuousSpace, Space
|
|
|
13
13
|
from .state import State
|
|
14
14
|
from .step import Step
|
|
15
15
|
|
|
16
|
-
ActionType = TypeVar("ActionType", default=
|
|
17
|
-
ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=
|
|
16
|
+
ActionType = TypeVar("ActionType", default=Any)
|
|
17
|
+
ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=Any)
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@dataclass
|
|
@@ -108,6 +108,11 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
|
|
|
108
108
|
"""Whether the environment is multi-objective."""
|
|
109
109
|
return self.reward_space.size > 1
|
|
110
110
|
|
|
111
|
+
@property
|
|
112
|
+
def n_objectives(self) -> int:
|
|
113
|
+
"""The number of objectives in the environment."""
|
|
114
|
+
return self.reward_space.size
|
|
115
|
+
|
|
111
116
|
def sample_action(self) -> ActionType:
|
|
112
117
|
"""Sample an available action from the action space."""
|
|
113
118
|
return self.action_space.sample(self.available_actions()) # type: ignore
|
marlenv/models/episode.py
CHANGED
|
@@ -179,9 +179,9 @@ class Episode(Generic[A]):
|
|
|
179
179
|
@cached_property
|
|
180
180
|
def dones(self):
|
|
181
181
|
"""The done flags for each transition"""
|
|
182
|
-
dones = np.zeros_like(self.rewards, dtype=np.
|
|
182
|
+
dones = np.zeros_like(self.rewards, dtype=np.bool)
|
|
183
183
|
if self.is_done:
|
|
184
|
-
dones[self.episode_len - 1 :] =
|
|
184
|
+
dones[self.episode_len - 1 :] = True
|
|
185
185
|
return dones
|
|
186
186
|
|
|
187
187
|
@property
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Callable, Optional, TypeVar
|
|
4
|
+
|
|
5
|
+
T = TypeVar("T")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class Schedule:
|
|
10
|
+
"""
|
|
11
|
+
Schedules the value of a varaible over time.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
name: str
|
|
15
|
+
start_value: float
|
|
16
|
+
end_value: float
|
|
17
|
+
_t: int
|
|
18
|
+
n_steps: int
|
|
19
|
+
|
|
20
|
+
def __init__(self, start_value: float, end_value: float, n_steps: int):
|
|
21
|
+
self.start_value = start_value
|
|
22
|
+
self.end_value = end_value
|
|
23
|
+
self.n_steps = n_steps
|
|
24
|
+
self.name = self.__class__.__name__
|
|
25
|
+
self._t = 0
|
|
26
|
+
self._current_value = self.start_value
|
|
27
|
+
|
|
28
|
+
def update(self, step: Optional[int] = None):
|
|
29
|
+
"""Update the value of the schedule. Force a step if given."""
|
|
30
|
+
if step is not None:
|
|
31
|
+
self._t = step
|
|
32
|
+
else:
|
|
33
|
+
self._t += 1
|
|
34
|
+
if self._t >= self.n_steps:
|
|
35
|
+
self._current_value = self.end_value
|
|
36
|
+
else:
|
|
37
|
+
self._current_value = self._compute()
|
|
38
|
+
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def _compute(self) -> float:
|
|
41
|
+
"""Compute the value of the schedule"""
|
|
42
|
+
|
|
43
|
+
@property
|
|
44
|
+
def value(self) -> float:
|
|
45
|
+
"""Returns the current value of the schedule"""
|
|
46
|
+
return self._current_value
|
|
47
|
+
|
|
48
|
+
@staticmethod
|
|
49
|
+
def constant(value: float):
|
|
50
|
+
return ConstantSchedule(value)
|
|
51
|
+
|
|
52
|
+
@staticmethod
|
|
53
|
+
def linear(start_value: float, end_value: float, n_steps: int):
|
|
54
|
+
return LinearSchedule(start_value, end_value, n_steps)
|
|
55
|
+
|
|
56
|
+
@staticmethod
|
|
57
|
+
def exp(start_value: float, end_value: float, n_steps: int):
|
|
58
|
+
return ExpSchedule(start_value, end_value, n_steps)
|
|
59
|
+
|
|
60
|
+
@staticmethod
|
|
61
|
+
def arbitrary(func: Callable[[int], float], n_steps: Optional[int] = None):
|
|
62
|
+
if n_steps is None:
|
|
63
|
+
n_steps = 0
|
|
64
|
+
return ArbitrarySchedule(func, n_steps)
|
|
65
|
+
|
|
66
|
+
def rounded(self, n_digits: int = 0) -> "RoundedSchedule":
|
|
67
|
+
return RoundedSchedule(self, n_digits)
|
|
68
|
+
|
|
69
|
+
# Operator overloading
|
|
70
|
+
def __mul__(self, other: T) -> T:
|
|
71
|
+
return self.value * other # type: ignore
|
|
72
|
+
|
|
73
|
+
def __rmul__(self, other: T) -> T:
|
|
74
|
+
return self.value * other # type: ignore
|
|
75
|
+
|
|
76
|
+
def __pow__(self, exp: float) -> float:
|
|
77
|
+
return self.value**exp
|
|
78
|
+
|
|
79
|
+
def __rpow__(self, other: T) -> T:
|
|
80
|
+
return other**self.value # type: ignore
|
|
81
|
+
|
|
82
|
+
def __add__(self, other: T) -> T:
|
|
83
|
+
return self.value + other # type: ignore
|
|
84
|
+
|
|
85
|
+
def __radd__(self, other: T) -> T:
|
|
86
|
+
return self.value + other # type: ignore
|
|
87
|
+
|
|
88
|
+
def __neg__(self):
|
|
89
|
+
return -self.value
|
|
90
|
+
|
|
91
|
+
def __pos__(self):
|
|
92
|
+
return +self.value
|
|
93
|
+
|
|
94
|
+
def __sub__(self, other: T) -> T:
|
|
95
|
+
return self.value - other # type: ignore
|
|
96
|
+
|
|
97
|
+
def __rsub__(self, other: T) -> T:
|
|
98
|
+
return other - self.value # type: ignore
|
|
99
|
+
|
|
100
|
+
def __div__(self, other: T) -> T:
|
|
101
|
+
return self.value // other # type: ignore
|
|
102
|
+
|
|
103
|
+
def __rdiv__(self, other: T) -> T:
|
|
104
|
+
return other // self.value # type: ignore
|
|
105
|
+
|
|
106
|
+
def __truediv__(self, other: T) -> T:
|
|
107
|
+
return self.value / other # type: ignore
|
|
108
|
+
|
|
109
|
+
def __rtruediv__(self, other: T) -> T:
|
|
110
|
+
return other / self.value # type: ignore
|
|
111
|
+
|
|
112
|
+
def __lt__(self, other) -> bool:
|
|
113
|
+
return self.value < other
|
|
114
|
+
|
|
115
|
+
def __le__(self, other) -> bool:
|
|
116
|
+
return self.value <= other
|
|
117
|
+
|
|
118
|
+
def __gt__(self, other) -> bool:
|
|
119
|
+
return self.value > other
|
|
120
|
+
|
|
121
|
+
def __ge__(self, other) -> bool:
|
|
122
|
+
return self.value >= other
|
|
123
|
+
|
|
124
|
+
def __eq__(self, other) -> bool:
|
|
125
|
+
if isinstance(other, Schedule):
|
|
126
|
+
if self.start_value != other.start_value:
|
|
127
|
+
return False
|
|
128
|
+
if self.end_value != other.end_value:
|
|
129
|
+
return False
|
|
130
|
+
if self.n_steps != other.n_steps:
|
|
131
|
+
return False
|
|
132
|
+
if type(self) is not type(other):
|
|
133
|
+
return False
|
|
134
|
+
return self.value == other
|
|
135
|
+
|
|
136
|
+
def __ne__(self, other) -> bool:
|
|
137
|
+
return not (self.__eq__(other))
|
|
138
|
+
|
|
139
|
+
def __float__(self):
|
|
140
|
+
return self.value
|
|
141
|
+
|
|
142
|
+
def __int__(self) -> int:
|
|
143
|
+
return int(self.value)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@dataclass(eq=False)
|
|
147
|
+
class LinearSchedule(Schedule):
|
|
148
|
+
a: float
|
|
149
|
+
b: float
|
|
150
|
+
|
|
151
|
+
def __init__(self, start_value: float, end_value: float, n_steps: int):
|
|
152
|
+
super().__init__(start_value, end_value, n_steps)
|
|
153
|
+
self._current_value = self.start_value
|
|
154
|
+
# y = ax + b
|
|
155
|
+
self.a = (self.end_value - self.start_value) / self.n_steps
|
|
156
|
+
self.b = self.start_value
|
|
157
|
+
|
|
158
|
+
def _compute(self):
|
|
159
|
+
return self.a * (self._t) + self.b
|
|
160
|
+
|
|
161
|
+
@property
|
|
162
|
+
def value(self) -> float:
|
|
163
|
+
return self._current_value
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@dataclass(eq=False)
|
|
167
|
+
class ExpSchedule(Schedule):
|
|
168
|
+
"""Exponential schedule. After n_steps, the value will be min_value.
|
|
169
|
+
|
|
170
|
+
Update formula is next_value = start_value * (min_value / start_value) ** (step / (n - 1))
|
|
171
|
+
"""
|
|
172
|
+
|
|
173
|
+
n_steps: int
|
|
174
|
+
base: float
|
|
175
|
+
last_update_step: int
|
|
176
|
+
|
|
177
|
+
def __init__(self, start_value: float, min_value: float, n_steps: int):
|
|
178
|
+
super().__init__(start_value, min_value, n_steps)
|
|
179
|
+
self.base = self.end_value / self.start_value
|
|
180
|
+
self.last_update_step = self.n_steps - 1
|
|
181
|
+
|
|
182
|
+
def _compute(self):
|
|
183
|
+
return self.start_value * (self.base) ** (self._t / (self.n_steps - 1))
|
|
184
|
+
|
|
185
|
+
@property
|
|
186
|
+
def value(self) -> float:
|
|
187
|
+
return self._current_value
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@dataclass(eq=False)
|
|
191
|
+
class ConstantSchedule(Schedule):
|
|
192
|
+
def __init__(self, value: float):
|
|
193
|
+
super().__init__(value, value, 0)
|
|
194
|
+
self._value = value
|
|
195
|
+
|
|
196
|
+
def update(self, step=None):
|
|
197
|
+
return
|
|
198
|
+
|
|
199
|
+
@property
|
|
200
|
+
def value(self) -> float:
|
|
201
|
+
return self._value
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@dataclass(eq=False)
|
|
205
|
+
class RoundedSchedule(Schedule):
|
|
206
|
+
def __init__(self, schedule: Schedule, n_digits: int):
|
|
207
|
+
super().__init__(schedule.start_value, schedule.end_value, schedule.n_steps)
|
|
208
|
+
self.schedule = schedule
|
|
209
|
+
self.n_digits = n_digits
|
|
210
|
+
|
|
211
|
+
def update(self, step: int | None = None):
|
|
212
|
+
return self.schedule.update(step)
|
|
213
|
+
|
|
214
|
+
def _compute(self) -> float:
|
|
215
|
+
return self.schedule._compute()
|
|
216
|
+
|
|
217
|
+
@property
|
|
218
|
+
def value(self) -> float:
|
|
219
|
+
return round(self.schedule.value, self.n_digits)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@dataclass(eq=False)
|
|
223
|
+
class MultiSchedule(Schedule):
|
|
224
|
+
def __init__(self, schedules: dict[int, Schedule]):
|
|
225
|
+
ordered_schedules, ordered_steps = MultiSchedule._verify(schedules)
|
|
226
|
+
n_steps = ordered_steps[-1] + ordered_schedules[-1].n_steps
|
|
227
|
+
super().__init__(ordered_schedules[0].start_value, ordered_schedules[-1].end_value, n_steps)
|
|
228
|
+
self.schedules = iter(ordered_schedules)
|
|
229
|
+
self.current_schedule = next(self.schedules)
|
|
230
|
+
self.offset = 0
|
|
231
|
+
self.current_end = ordered_steps[1]
|
|
232
|
+
|
|
233
|
+
@staticmethod
|
|
234
|
+
def _verify(schedules: dict[int, Schedule]):
|
|
235
|
+
sorted_steps = sorted(schedules.keys())
|
|
236
|
+
sorted_schedules = [schedules[t] for t in sorted_steps]
|
|
237
|
+
if sorted_steps[0] != 0:
|
|
238
|
+
raise ValueError("First schedule must start at t=0")
|
|
239
|
+
current_step = 0
|
|
240
|
+
for i in range(len(sorted_steps)):
|
|
241
|
+
# Artificially set the end step of ConstantSchedules to the next step
|
|
242
|
+
if isinstance(sorted_schedules[i], ConstantSchedule):
|
|
243
|
+
if i + 1 < len(sorted_steps):
|
|
244
|
+
sorted_schedules[i].n_steps = sorted_steps[i + 1] - sorted_steps[i]
|
|
245
|
+
if sorted_steps[i] != current_step:
|
|
246
|
+
raise ValueError(f"Schedules are not contiguous at t={current_step}")
|
|
247
|
+
current_step += sorted_schedules[i].n_steps
|
|
248
|
+
return sorted_schedules, sorted_steps
|
|
249
|
+
|
|
250
|
+
def update(self, step: int | None = None):
|
|
251
|
+
if step is not None:
|
|
252
|
+
raise NotImplementedError("Cannot update MultiSchedule with a specific step")
|
|
253
|
+
super().update(step)
|
|
254
|
+
# If we reach the end of the current schedule, update to the next one
|
|
255
|
+
# except if we are at the end.
|
|
256
|
+
if self._t == self.current_end and self._t < self.n_steps:
|
|
257
|
+
self.current_schedule = next(self.schedules)
|
|
258
|
+
self.offset = self._t
|
|
259
|
+
self.current_end += self.current_schedule.n_steps
|
|
260
|
+
self.current_schedule.update(self.relative_step)
|
|
261
|
+
|
|
262
|
+
@property
|
|
263
|
+
def relative_step(self):
|
|
264
|
+
return self._t - self.offset
|
|
265
|
+
|
|
266
|
+
def _compute(self) -> float:
|
|
267
|
+
return self.current_schedule._compute()
|
|
268
|
+
|
|
269
|
+
@property
|
|
270
|
+
def value(self):
|
|
271
|
+
return self.current_schedule.value
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
@dataclass(eq=False)
|
|
275
|
+
class ArbitrarySchedule(Schedule):
|
|
276
|
+
def __init__(self, fn: Callable[[int], float], n_steps: int):
|
|
277
|
+
super().__init__(fn(0), fn(n_steps), n_steps)
|
|
278
|
+
self._func = fn
|
|
279
|
+
|
|
280
|
+
def _compute(self) -> float:
|
|
281
|
+
return self._func(self._t)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: multi-agent-rlenv
|
|
3
|
-
Version: 3.3.7
|
|
3
|
+
Version: 3.4.0
|
|
4
4
|
Summary: A strongly typed Multi-Agent Reinforcement Learning framework
|
|
5
5
|
Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
|
|
6
6
|
Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
marlenv/__init__.py,sha256=
|
|
1
|
+
marlenv/__init__.py,sha256=sDucG8AdBbAOoO21QpTuMymZSPvp5fq6EBqCgTzFLkk,3741
|
|
2
2
|
marlenv/env_builder.py,sha256=s_lQANqP3iNc8nmcr3CanRVsExnn9qh0ihh4lFr0c4c,5560
|
|
3
3
|
marlenv/env_pool.py,sha256=R3WIrnQ5Zvff4HR1ecfkDmuO2zl7v1ywQ0K2_nvWFzs,1070
|
|
4
4
|
marlenv/exceptions.py,sha256=gJUC_2rVAvOfK_ypVFc7Myh-pIfSU3To38VBVS_0rZA,1179
|
|
@@ -6,18 +6,20 @@ marlenv/mock_env.py,sha256=qB0fYFIfbopJf7Va8kCeVI5vsOy1-2JdEYe9gdV1Ruw,4761
|
|
|
6
6
|
marlenv/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
7
|
marlenv/adapters/__init__.py,sha256=rWiqQOqTx3kVL5ZkPo3rkczrlQBBhQbU55zGI26SEeY,929
|
|
8
8
|
marlenv/adapters/gym_adapter.py,sha256=6CBEjANViTJBTUBmtVyrhJrzjBJxNs_4hmMnXXG2mkU,2906
|
|
9
|
-
marlenv/adapters/overcooked_adapter.py,sha256=
|
|
9
|
+
marlenv/adapters/overcooked_adapter.py,sha256=gu0TOp-FiLUVOskDnaqGA2D44P0qJNk4KWznHA2M8L8,9174
|
|
10
10
|
marlenv/adapters/pettingzoo_adapter.py,sha256=4F1au6uctsqRhGfcZOeDRH-8hmrFXnA5xH1Z1Pnek3s,2870
|
|
11
11
|
marlenv/adapters/pymarl_adapter.py,sha256=x__E90XpFbfSWhnBHtkcD6WYkmKki1LByNbUFoDBUcg,3416
|
|
12
12
|
marlenv/adapters/smac_adapter.py,sha256=fOfKo1hL4ioKtM5qQGcwtfdkdwUEACjAZqaGmkoQUcU,8373
|
|
13
13
|
marlenv/models/__init__.py,sha256=9M-rnj94nsdyO4zm_VEtyYBmde3iD2_eIY4bMB-IqCo,555
|
|
14
|
-
marlenv/models/env.py,sha256=
|
|
15
|
-
marlenv/models/episode.py,sha256=
|
|
14
|
+
marlenv/models/env.py,sha256=54I6SWkzO3nXZ0L7bRWpKJH_ywDu8iS-S2YwHL3yXDU,7569
|
|
15
|
+
marlenv/models/episode.py,sha256=hExUIcOjXImA-hfOgIWnq6sQPJAgZYhT6pQ1x4SBAjk,15138
|
|
16
16
|
marlenv/models/observation.py,sha256=kAmh1hIoC2TGrZlGVzV0y4TXXCSrI7gcmG0raeoncYk,3153
|
|
17
17
|
marlenv/models/spaces.py,sha256=pw8Sum_fHBkR-lyfTqUij4azMCNm8oBZrYZe4WVR7rA,7652
|
|
18
18
|
marlenv/models/state.py,sha256=958PXTHadi3gtRnhGgcGtqBnF44R11kdcx62NN2gwxA,1717
|
|
19
19
|
marlenv/models/step.py,sha256=LKGAV2Cu-k9Gz1hwrfvGx51l8axtQRqDE9WVL5r2A1Q,3037
|
|
20
20
|
marlenv/models/transition.py,sha256=2vvuhSSq911weCXio9nuyfsLVh_7ORSU_znOqpLLdLg,5107
|
|
21
|
+
marlenv/utils/__init__.py,sha256=C3qhvkVwctBP8mG3G5nkAZ5DKfErVRkdbHo7oeWVsM0,209
|
|
22
|
+
marlenv/utils/schedule.py,sha256=slhtpQiBHSUNyPmSkKb2yBgiHJqPhoPxa33GxvyV8Jc,8565
|
|
21
23
|
marlenv/wrappers/__init__.py,sha256=wl23NUYcl0vPJb2QLpe4Xj8ZocUIOarAZX8CgWqdSQE,808
|
|
22
24
|
marlenv/wrappers/agent_id_wrapper.py,sha256=oTIAYxKD1JtHfrZN43mf-3e8pxjd0nxm07vxs3BfrGY,1187
|
|
23
25
|
marlenv/wrappers/available_actions_mask.py,sha256=JoCJ9eqHlkY8wfY-oaceEi8yp1Efs1iK6IO2Ibf9oZA,1468
|
|
@@ -31,7 +33,7 @@ marlenv/wrappers/penalty_wrapper.py,sha256=v4_H8OEN2-yujLzRb6P7W7KwmXHtjAFsxcdp3
|
|
|
31
33
|
marlenv/wrappers/rlenv_wrapper.py,sha256=C2XekgBIM4x3Wa2Mtsn7rihRD4ymC2hORI473Af0sfw,2962
|
|
32
34
|
marlenv/wrappers/time_limit.py,sha256=CDIMMJPMyIDHSFxUJaC7nb7Kd86-07NgZeFhrpZm82o,3985
|
|
33
35
|
marlenv/wrappers/video_recorder.py,sha256=d5AFu6qHqby9mOcBsYWYPxAPiK1vtnfMYdZ81AnCekI,2624
|
|
34
|
-
multi_agent_rlenv-3.
|
|
35
|
-
multi_agent_rlenv-3.
|
|
36
|
-
multi_agent_rlenv-3.
|
|
37
|
-
multi_agent_rlenv-3.
|
|
36
|
+
multi_agent_rlenv-3.4.0.dist-info/METADATA,sha256=5TitLGgWA_BrpC_xhFHDufA4QB7v1EZ0FHP3Hdtbf5k,4897
|
|
37
|
+
multi_agent_rlenv-3.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
38
|
+
multi_agent_rlenv-3.4.0.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
|
|
39
|
+
multi_agent_rlenv-3.4.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|