multi-agent-rlenv 3.3.6__py3-none-any.whl → 3.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- marlenv/__init__.py +1 -1
- marlenv/adapters/overcooked_adapter.py +60 -8
- marlenv/models/env.py +8 -3
- marlenv/models/episode.py +2 -2
- marlenv/utils/__init__.py +10 -0
- marlenv/utils/schedule.py +281 -0
- {multi_agent_rlenv-3.3.6.dist-info → multi_agent_rlenv-3.4.0.dist-info}/METADATA +1 -1
- {multi_agent_rlenv-3.3.6.dist-info → multi_agent_rlenv-3.4.0.dist-info}/RECORD +10 -8
- {multi_agent_rlenv-3.3.6.dist-info → multi_agent_rlenv-3.4.0.dist-info}/WHEEL +0 -0
- {multi_agent_rlenv-3.3.6.dist-info → multi_agent_rlenv-3.4.0.dist-info}/licenses/LICENSE +0 -0
marlenv/__init__.py
CHANGED
|
@@ -62,7 +62,7 @@ print(env.extras_shape) # (1, )
|
|
|
62
62
|
If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
|
|
63
63
|
"""
|
|
64
64
|
|
|
65
|
-
__version__ = "3.
|
|
65
|
+
__version__ = "3.4.0"
|
|
66
66
|
|
|
67
67
|
from . import models
|
|
68
68
|
from . import wrappers
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import sys
|
|
2
2
|
from dataclasses import dataclass
|
|
3
|
-
from typing import Literal, Sequence
|
|
3
|
+
from typing import Literal, Sequence, Optional
|
|
4
4
|
from copy import deepcopy
|
|
5
5
|
|
|
6
6
|
import cv2
|
|
@@ -8,6 +8,7 @@ import numpy as np
|
|
|
8
8
|
import numpy.typing as npt
|
|
9
9
|
import pygame
|
|
10
10
|
from marlenv.models import ContinuousSpace, DiscreteActionSpace, MARLEnv, Observation, State, Step
|
|
11
|
+
from marlenv.utils import Schedule
|
|
11
12
|
|
|
12
13
|
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
|
|
13
14
|
from overcooked_ai_py.mdp.overcooked_mdp import Action, OvercookedGridworld, OvercookedState
|
|
@@ -17,8 +18,17 @@ from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
|
|
|
17
18
|
@dataclass
|
|
18
19
|
class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
|
|
19
20
|
horizon: int
|
|
21
|
+
shaping_factor: Schedule
|
|
20
22
|
|
|
21
|
-
def __init__(
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
oenv: OvercookedEnv,
|
|
26
|
+
shaping_factor: float | Schedule = 1.0,
|
|
27
|
+
name_suffix: Optional[str] = None,
|
|
28
|
+
):
|
|
29
|
+
if isinstance(shaping_factor, (int, float)):
|
|
30
|
+
shaping_factor = Schedule.constant(shaping_factor)
|
|
31
|
+
self.shaping_factor = shaping_factor
|
|
22
32
|
self._oenv = oenv
|
|
23
33
|
assert isinstance(oenv.mdp, OvercookedGridworld)
|
|
24
34
|
self._mdp = oenv.mdp
|
|
@@ -40,6 +50,8 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
|
|
|
40
50
|
reward_space=ContinuousSpace.from_shape(1),
|
|
41
51
|
)
|
|
42
52
|
self.horizon = int(self._oenv.horizon)
|
|
53
|
+
if name_suffix is not None:
|
|
54
|
+
self.name = f"{self.name}-{name_suffix}"
|
|
43
55
|
|
|
44
56
|
@property
|
|
45
57
|
def state(self) -> OvercookedState:
|
|
@@ -84,12 +96,15 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
|
|
|
84
96
|
return np.array(available_actions, dtype=np.bool)
|
|
85
97
|
|
|
86
98
|
def step(self, actions: Sequence[int] | npt.NDArray[np.int32 | np.int64]) -> Step:
|
|
99
|
+
self.shaping_factor.update()
|
|
87
100
|
actions = [Action.ALL_ACTIONS[a] for a in actions]
|
|
88
101
|
_, reward, done, info = self._oenv.step(actions, display_phi=True)
|
|
102
|
+
|
|
103
|
+
reward += sum(info["shaped_r_by_agent"]) * self.shaping_factor
|
|
89
104
|
return Step(
|
|
90
105
|
obs=self.get_observation(),
|
|
91
106
|
state=self.get_state(),
|
|
92
|
-
reward=np.array([reward]),
|
|
107
|
+
reward=np.array([reward], dtype=np.float32),
|
|
93
108
|
done=done,
|
|
94
109
|
truncated=False,
|
|
95
110
|
info=info,
|
|
@@ -99,19 +114,25 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
|
|
|
99
114
|
self._oenv.reset()
|
|
100
115
|
return self.get_observation(), self.get_state()
|
|
101
116
|
|
|
102
|
-
def __deepcopy__(self, _):
    """
    Deep copy of the environment.

    The copy is a new `Overcooked` built from a deep copy of the MDP (with the same horizon) and a
    deep copy of the shaping-factor schedule, and it keeps the same name.

    Note: a specific implementation is needed because `pygame.font.Font` objects are not deep-copiable by default.
    """
    duplicate = Overcooked(
        OvercookedEnv.from_mdp(deepcopy(self._mdp), horizon=self.horizon),
        deepcopy(self.shaping_factor),
    )
    duplicate.name = self.name
    return duplicate
|
|
105
125
|
|
|
106
126
|
def __getstate__(self):
    """Return the picklable state of the environment: horizon, MDP, name and shaping-factor schedule."""
    return dict(
        horizon=self.horizon,
        mdp=self._mdp,
        name=self.name,
        schedule=self.shaping_factor,
    )
|
|
108
128
|
|
|
109
129
|
def __setstate__(self, state: dict):
    """
    Rebuild the environment from a state produced by `__getstate__`.

    States pickled with versions prior to 3.4.0 only contain the "horizon" and "mdp" keys. Those
    environments did not add shaped rewards to the reward, so they are restored with a shaping
    factor of 0 and keep the name assigned by `__init__`, instead of raising a `KeyError`.
    """
    from overcooked_ai_py.mdp.overcooked_mdp import Recipe

    mdp = state["mdp"]
    # Restore the recipe configuration of the pickled MDP before rebuilding the environment.
    Recipe.configure(mdp.recipe_config)
    self.__init__(
        OvercookedEnv.from_mdp(mdp, horizon=state["horizon"]),
        shaping_factor=state.get("schedule", 0.0),
    )
    if "name" in state:
        self.name = state["name"]
|
|
115
136
|
|
|
116
137
|
def get_image(self):
|
|
117
138
|
rewards_dict = {} # dictionary of details you want rendered in the UI
|
|
@@ -185,6 +206,37 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
|
|
|
185
206
|
"you_shall_not_pass",
|
|
186
207
|
],
|
|
187
208
|
horizon: int = 400,
|
|
209
|
+
reward_shaping_factor: float | Schedule = 1.0,
|
|
188
210
|
):
|
|
189
211
|
mdp = OvercookedGridworld.from_layout_name(layout)
|
|
190
|
-
return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon))
|
|
212
|
+
return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), reward_shaping_factor, layout)
|
|
213
|
+
|
|
214
|
+
@staticmethod
def from_grid(
    grid: Sequence[Sequence[Literal["S", "P", "X", "O", "D", "T", "1", "2", " "] | str]],
    horizon: int = 400,
    shaping_factor: float | Schedule = 1.0,
    layout_name: Optional[str] = None,
):
    """
    Create an Overcooked environment from a grid layout where
    - S is a serving location
    - P is a cooking pot
    - X is a counter
    - O is an onion dispenser
    - D is a dish dispenser
    - T is a tomato dispenser
    - 1 is a player 1 starting location
    - 2 is a player 2 starting location
    - ' ' is a walkable space

    Args:
        grid: the layout, made of the characters above.
        horizon: forwarded to `OvercookedEnv.from_mdp` as the environment horizon.
        shaping_factor: constant or `Schedule` that multiplies the shaped rewards added to the reward.
        layout_name: name of the layout. It is given to Overcooked as the layout name and is also
            appended to the environment name (`"<name>-<layout_name>"`). Defaults to "custom-layout".
    """
    # It is necessary to add an explicit layout name because Overcooked saves some files under this
    # name. By default the name is a concatenation of the grid elements, which may include characters
    # such as white spaces, pipes ('|') and square brackets ('[' and ']') that are invalid Windows file paths.
    if layout_name is None:
        layout_name = "custom-layout"
    mdp = OvercookedGridworld.from_grid(grid, base_layout_params={"layout_name": layout_name})
    return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), shaping_factor, layout_name)
|
marlenv/models/env.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from itertools import product
|
|
4
|
-
from typing import Generic, Optional, Sequence
|
|
4
|
+
from typing import Any, Generic, Optional, Sequence
|
|
5
5
|
|
|
6
6
|
import cv2
|
|
7
7
|
import numpy as np
|
|
@@ -13,8 +13,8 @@ from .spaces import ActionSpace, ContinuousSpace, Space
|
|
|
13
13
|
from .state import State
|
|
14
14
|
from .step import Step
|
|
15
15
|
|
|
16
|
-
ActionType = TypeVar("ActionType", default=
|
|
17
|
-
ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=
|
|
16
|
+
ActionType = TypeVar("ActionType", default=Any)
|
|
17
|
+
ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=Any)
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@dataclass
|
|
@@ -108,6 +108,11 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
|
|
|
108
108
|
"""Whether the environment is multi-objective."""
|
|
109
109
|
return self.reward_space.size > 1
|
|
110
110
|
|
|
111
|
+
@property
def n_objectives(self) -> int:
    """Number of objectives, i.e. the size of `reward_space` (more than 1 means multi-objective)."""
    reward_space = self.reward_space
    return reward_space.size
|
|
115
|
+
|
|
111
116
|
def sample_action(self) -> ActionType:
|
|
112
117
|
"""Sample an available action from the action space."""
|
|
113
118
|
return self.action_space.sample(self.available_actions()) # type: ignore
|
marlenv/models/episode.py
CHANGED
|
@@ -179,9 +179,9 @@ class Episode(Generic[A]):
|
|
|
179
179
|
@cached_property
def dones(self):
    """
    The done flags for each transition.

    Boolean array with the same shape as `self.rewards`. If the episode is done, every flag from
    index `episode_len - 1` onwards is True; all other flags are False.
    """
    # `np.bool_` rather than `np.bool`: the `np.bool` alias does not exist in NumPy 1.24-1.26
    # (removed in 1.24, re-added in 2.0), whereas `np.bool_` works with every NumPy version.
    dones = np.zeros_like(self.rewards, dtype=np.bool_)
    if self.is_done:
        dones[self.episode_len - 1 :] = True
    return dones
|
|
186
186
|
|
|
187
187
|
@property
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Callable, Optional, TypeVar
|
|
4
|
+
|
|
5
|
+
T = TypeVar("T")  # Type of the other operand in Schedule's arithmetic operators.


@dataclass
class Schedule:
    """
    Schedules the value of a variable over time.

    The schedule starts at `start_value`. Every call to `update()` advances an internal step
    counter by one (or sets it to the given step). While the counter is below `n_steps`, the
    value is computed by the subclass' `_compute()`; from `n_steps` onwards it is `end_value`.

    A schedule can be used in arithmetic and comparison expressions, where it behaves like its
    current `value` (e.g. `reward * schedule`).
    """

    name: str
    start_value: float
    end_value: float
    _t: int
    n_steps: int

    def __init__(self, start_value: float, end_value: float, n_steps: int):
        self.start_value = start_value
        self.end_value = end_value
        self.n_steps = n_steps
        self.name = self.__class__.__name__
        self._t = 0
        self._current_value = self.start_value

    def update(self, step: Optional[int] = None):
        """Update the value of the schedule. Force a step if given."""
        if step is not None:
            self._t = step
        else:
            self._t += 1
        if self._t >= self.n_steps:
            self._current_value = self.end_value
        else:
            self._current_value = self._compute()

    @abstractmethod
    def _compute(self) -> float:
        """Compute the value of the schedule at step `self._t` (only called while `_t < n_steps`)."""

    @property
    def value(self) -> float:
        """Returns the current value of the schedule"""
        return self._current_value

    @staticmethod
    def constant(value: float):
        """Schedule whose value is always `value`."""
        return ConstantSchedule(value)

    @staticmethod
    def linear(start_value: float, end_value: float, n_steps: int):
        """Schedule going linearly from `start_value` to `end_value` in `n_steps` updates."""
        return LinearSchedule(start_value, end_value, n_steps)

    @staticmethod
    def exp(start_value: float, end_value: float, n_steps: int):
        """Schedule going exponentially from `start_value` to `end_value`."""
        return ExpSchedule(start_value, end_value, n_steps)

    @staticmethod
    def arbitrary(func: Callable[[int], float], n_steps: Optional[int] = None):
        """Schedule whose value at step t is `func(t)`. `n_steps` defaults to 0 when not given."""
        if n_steps is None:
            n_steps = 0
        return ArbitrarySchedule(func, n_steps)

    def rounded(self, n_digits: int = 0) -> "RoundedSchedule":
        """Wrap this schedule so that its value is rounded to `n_digits` decimal digits."""
        return RoundedSchedule(self, n_digits)

    # Operator overloading: every operator acts on the current value of the schedule.
    def __mul__(self, other: T) -> T:
        return self.value * other  # type: ignore

    def __rmul__(self, other: T) -> T:
        return self.value * other  # type: ignore

    def __pow__(self, exp: float) -> float:
        return self.value**exp

    def __rpow__(self, other: T) -> T:
        return other**self.value  # type: ignore

    def __add__(self, other: T) -> T:
        return self.value + other  # type: ignore

    def __radd__(self, other: T) -> T:
        return self.value + other  # type: ignore

    def __neg__(self):
        return -self.value

    def __pos__(self):
        return +self.value

    def __sub__(self, other: T) -> T:
        return self.value - other  # type: ignore

    def __rsub__(self, other: T) -> T:
        return other - self.value  # type: ignore

    def __floordiv__(self, other: T) -> T:
        return self.value // other  # type: ignore

    def __rfloordiv__(self, other: T) -> T:
        return other // self.value  # type: ignore

    # Python 2 operator names, kept for backward compatibility. Python 3 never calls them
    # implicitly: `//` dispatches to `__floordiv__` / `__rfloordiv__` above.
    __div__ = __floordiv__
    __rdiv__ = __rfloordiv__

    def __truediv__(self, other: T) -> T:
        return self.value / other  # type: ignore

    def __rtruediv__(self, other: T) -> T:
        return other / self.value  # type: ignore

    def __lt__(self, other) -> bool:
        return self.value < other

    def __le__(self, other) -> bool:
        return self.value <= other

    def __gt__(self, other) -> bool:
        return self.value > other

    def __ge__(self, other) -> bool:
        return self.value >= other

    def __eq__(self, other) -> bool:
        if isinstance(other, Schedule):
            if self.start_value != other.start_value:
                return False
            if self.end_value != other.end_value:
                return False
            if self.n_steps != other.n_steps:
                return False
            if type(self) is not type(other):
                return False
            # Compare the current values explicitly instead of relying on the reflected `__eq__`.
            return self.value == other.value
        return self.value == other

    def __ne__(self, other) -> bool:
        return not (self.__eq__(other))

    def __float__(self) -> float:
        # `__float__` must return a float: the value may be an int (e.g. `Schedule.constant(1)`).
        return float(self.value)

    def __int__(self) -> int:
        return int(self.value)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@dataclass(eq=False)
class LinearSchedule(Schedule):
    """
    Linear schedule: the value at step t is `a * t + b`, with `b = start_value` and
    `a = (end_value - start_value) / n_steps`, so that `end_value` is reached at t = n_steps.
    """

    a: float
    b: float

    def __init__(self, start_value: float, end_value: float, n_steps: int):
        super().__init__(start_value, end_value, n_steps)
        # y = ax + b
        if self.n_steps > 0:
            self.a = (self.end_value - self.start_value) / self.n_steps
        else:
            # With n_steps <= 0, `Schedule.update` jumps straight to `end_value`, so the slope is
            # never used: avoid the division by zero instead of crashing at construction.
            self.a = 0.0
        self.b = self.start_value

    def _compute(self) -> float:
        return self.a * self._t + self.b
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@dataclass(eq=False)
class ExpSchedule(Schedule):
    """Exponential schedule from `start_value` to `min_value`.

    The value at step t is start_value * (min_value / start_value) ** (t / (n_steps - 1)), so it
    reaches min_value at t = n_steps - 1 and stays at min_value afterwards.
    """

    n_steps: int
    base: float
    last_update_step: int

    def __init__(self, start_value: float, min_value: float, n_steps: int):
        super().__init__(start_value, min_value, n_steps)
        self.base = self.end_value / self.start_value
        self.last_update_step = self.n_steps - 1

    def _compute(self):
        # `_compute` only runs while _t < n_steps. With n_steps <= 1 that means t = 0, where the
        # formula would divide by zero (n_steps - 1 == 0) but its value is start_value anyway.
        if self.n_steps <= 1:
            return self.start_value
        return self.start_value * self.base ** (self._t / (self.n_steps - 1))
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@dataclass(eq=False)
class ConstantSchedule(Schedule):
    """A schedule whose value never changes, however many times it is updated."""

    def __init__(self, value: float):
        # A constant has no duration of its own: start and end values are both `value`.
        super().__init__(start_value=value, end_value=value, n_steps=0)
        self._value = value

    def update(self, step=None):
        # Nothing to advance: updating a constant is deliberately a no-op.
        return None

    @property
    def value(self) -> float:
        return self._value
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@dataclass(eq=False)
class RoundedSchedule(Schedule):
    """Wraps another schedule and exposes its value rounded to `n_digits` decimal digits."""

    def __init__(self, schedule: Schedule, n_digits: int):
        inner = schedule
        super().__init__(inner.start_value, inner.end_value, inner.n_steps)
        self.schedule = inner
        self.n_digits = n_digits

    def update(self, step: int | None = None):
        # Progress is tracked by the wrapped schedule only.
        inner = self.schedule
        return inner.update(step)

    def _compute(self) -> float:
        inner = self.schedule
        return inner._compute()

    @property
    def value(self) -> float:
        raw_value = self.schedule.value
        return round(raw_value, self.n_digits)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@dataclass(eq=False)
class MultiSchedule(Schedule):
    """
    Chains several schedules one after the other.

    `schedules` maps the global step at which each schedule starts to that schedule. The first
    schedule must start at t=0, and each following one must start exactly when the previous one
    ends (its start step + its `n_steps`). A `ConstantSchedule` that is not the last one lasts
    until the next schedule starts; a trailing `ConstantSchedule` never ends.

    Only step-by-step updates are supported (`update()` without a forced step).
    """

    def __init__(self, schedules: dict[int, Schedule]):
        ordered_schedules, ordered_steps = MultiSchedule._verify(schedules)
        n_steps = ordered_steps[-1] + ordered_schedules[-1].n_steps
        super().__init__(ordered_schedules[0].start_value, ordered_schedules[-1].end_value, n_steps)
        self.schedules = iter(ordered_schedules)
        self.current_schedule = next(self.schedules)
        # Number of schedules still waiting in `self.schedules`.
        self._n_remaining = len(ordered_schedules) - 1
        # Global step at which the current schedule started.
        self.offset = 0
        # Global step at which the current schedule ends. `_verify` guarantees that this is the
        # start step of the next schedule, and it is also defined when there is a single schedule.
        self.current_end = ordered_schedules[0].n_steps

    @staticmethod
    def _verify(schedules: dict[int, Schedule]):
        """
        Sort the schedules by start step and check that they are contiguous from t=0.

        Note: the `n_steps` of every `ConstantSchedule` except the last one is overwritten in place
        so that it lasts until the next schedule starts.

        Raises:
            ValueError: if there is no schedule, if the first one does not start at t=0, or if the
                schedules are not contiguous.
        """
        if len(schedules) == 0:
            raise ValueError("At least one schedule is required")
        sorted_steps = sorted(schedules.keys())
        sorted_schedules = [schedules[t] for t in sorted_steps]
        if sorted_steps[0] != 0:
            raise ValueError("First schedule must start at t=0")
        current_step = 0
        for i, (start, schedule) in enumerate(zip(sorted_steps, sorted_schedules)):
            # Artificially set the end step of ConstantSchedules to the next step
            if isinstance(schedule, ConstantSchedule) and i + 1 < len(sorted_steps):
                schedule.n_steps = sorted_steps[i + 1] - start
            if start != current_step:
                raise ValueError(f"Schedules are not contiguous at t={current_step}")
            current_step += schedule.n_steps
        return sorted_schedules, sorted_steps

    def update(self, step: int | None = None):
        if step is not None:
            raise NotImplementedError("Cannot update MultiSchedule with a specific step")
        super().update(step)
        # Once the current schedule is exhausted, switch to the next one if there is any. The loop
        # (rather than a single check) also activates zero-length schedules such as a trailing
        # `ConstantSchedule`, which must take over as soon as the previous schedule ends.
        while self._t >= self.current_end and self._n_remaining > 0:
            self.current_schedule = next(self.schedules)
            self._n_remaining -= 1
            self.offset = self.current_end
            self.current_end += self.current_schedule.n_steps
        self.current_schedule.update(self.relative_step)

    @property
    def relative_step(self):
        """Step relative to the start of the current schedule."""
        return self._t - self.offset

    def _compute(self) -> float:
        return self.current_schedule._compute()

    @property
    def value(self):
        return self.current_schedule.value
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
@dataclass(eq=False)
class ArbitrarySchedule(Schedule):
    """
    Schedule whose value at step t is `fn(t)`.

    If `n_steps > 0`, the value is `fn(t)` while t < n_steps and `fn(n_steps)` afterwards.
    If `n_steps <= 0` (what `Schedule.arbitrary` uses when no `n_steps` is given), the schedule
    is unbounded and its value is always `fn(t)`.
    """

    def __init__(self, fn: Callable[[int], float], n_steps: int):
        super().__init__(fn(0), fn(n_steps), n_steps)
        self._func = fn

    def update(self, step: Optional[int] = None):
        if self.n_steps > 0:
            return super().update(step)
        # Unbounded schedule: `Schedule.update` would pin the value to `end_value` (= fn(0))
        # forever because `_t >= n_steps` always holds, so evaluate the function at every step.
        if step is not None:
            self._t = step
        else:
            self._t += 1
        self._current_value = self._compute()

    def _compute(self) -> float:
        return self._func(self._t)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: multi-agent-rlenv
|
|
3
|
-
Version: 3.3.6
|
|
3
|
+
Version: 3.4.0
|
|
4
4
|
Summary: A strongly typed Multi-Agent Reinforcement Learning framework
|
|
5
5
|
Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
|
|
6
6
|
Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
marlenv/__init__.py,sha256=
|
|
1
|
+
marlenv/__init__.py,sha256=sDucG8AdBbAOoO21QpTuMymZSPvp5fq6EBqCgTzFLkk,3741
|
|
2
2
|
marlenv/env_builder.py,sha256=s_lQANqP3iNc8nmcr3CanRVsExnn9qh0ihh4lFr0c4c,5560
|
|
3
3
|
marlenv/env_pool.py,sha256=R3WIrnQ5Zvff4HR1ecfkDmuO2zl7v1ywQ0K2_nvWFzs,1070
|
|
4
4
|
marlenv/exceptions.py,sha256=gJUC_2rVAvOfK_ypVFc7Myh-pIfSU3To38VBVS_0rZA,1179
|
|
@@ -6,18 +6,20 @@ marlenv/mock_env.py,sha256=qB0fYFIfbopJf7Va8kCeVI5vsOy1-2JdEYe9gdV1Ruw,4761
|
|
|
6
6
|
marlenv/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
7
|
marlenv/adapters/__init__.py,sha256=rWiqQOqTx3kVL5ZkPo3rkczrlQBBhQbU55zGI26SEeY,929
|
|
8
8
|
marlenv/adapters/gym_adapter.py,sha256=6CBEjANViTJBTUBmtVyrhJrzjBJxNs_4hmMnXXG2mkU,2906
|
|
9
|
-
marlenv/adapters/overcooked_adapter.py,sha256=
|
|
9
|
+
marlenv/adapters/overcooked_adapter.py,sha256=gu0TOp-FiLUVOskDnaqGA2D44P0qJNk4KWznHA2M8L8,9174
|
|
10
10
|
marlenv/adapters/pettingzoo_adapter.py,sha256=4F1au6uctsqRhGfcZOeDRH-8hmrFXnA5xH1Z1Pnek3s,2870
|
|
11
11
|
marlenv/adapters/pymarl_adapter.py,sha256=x__E90XpFbfSWhnBHtkcD6WYkmKki1LByNbUFoDBUcg,3416
|
|
12
12
|
marlenv/adapters/smac_adapter.py,sha256=fOfKo1hL4ioKtM5qQGcwtfdkdwUEACjAZqaGmkoQUcU,8373
|
|
13
13
|
marlenv/models/__init__.py,sha256=9M-rnj94nsdyO4zm_VEtyYBmde3iD2_eIY4bMB-IqCo,555
|
|
14
|
-
marlenv/models/env.py,sha256=
|
|
15
|
-
marlenv/models/episode.py,sha256=
|
|
14
|
+
marlenv/models/env.py,sha256=54I6SWkzO3nXZ0L7bRWpKJH_ywDu8iS-S2YwHL3yXDU,7569
|
|
15
|
+
marlenv/models/episode.py,sha256=hExUIcOjXImA-hfOgIWnq6sQPJAgZYhT6pQ1x4SBAjk,15138
|
|
16
16
|
marlenv/models/observation.py,sha256=kAmh1hIoC2TGrZlGVzV0y4TXXCSrI7gcmG0raeoncYk,3153
|
|
17
17
|
marlenv/models/spaces.py,sha256=pw8Sum_fHBkR-lyfTqUij4azMCNm8oBZrYZe4WVR7rA,7652
|
|
18
18
|
marlenv/models/state.py,sha256=958PXTHadi3gtRnhGgcGtqBnF44R11kdcx62NN2gwxA,1717
|
|
19
19
|
marlenv/models/step.py,sha256=LKGAV2Cu-k9Gz1hwrfvGx51l8axtQRqDE9WVL5r2A1Q,3037
|
|
20
20
|
marlenv/models/transition.py,sha256=2vvuhSSq911weCXio9nuyfsLVh_7ORSU_znOqpLLdLg,5107
|
|
21
|
+
marlenv/utils/__init__.py,sha256=C3qhvkVwctBP8mG3G5nkAZ5DKfErVRkdbHo7oeWVsM0,209
|
|
22
|
+
marlenv/utils/schedule.py,sha256=slhtpQiBHSUNyPmSkKb2yBgiHJqPhoPxa33GxvyV8Jc,8565
|
|
21
23
|
marlenv/wrappers/__init__.py,sha256=wl23NUYcl0vPJb2QLpe4Xj8ZocUIOarAZX8CgWqdSQE,808
|
|
22
24
|
marlenv/wrappers/agent_id_wrapper.py,sha256=oTIAYxKD1JtHfrZN43mf-3e8pxjd0nxm07vxs3BfrGY,1187
|
|
23
25
|
marlenv/wrappers/available_actions_mask.py,sha256=JoCJ9eqHlkY8wfY-oaceEi8yp1Efs1iK6IO2Ibf9oZA,1468
|
|
@@ -31,7 +33,7 @@ marlenv/wrappers/penalty_wrapper.py,sha256=v4_H8OEN2-yujLzRb6P7W7KwmXHtjAFsxcdp3
|
|
|
31
33
|
marlenv/wrappers/rlenv_wrapper.py,sha256=C2XekgBIM4x3Wa2Mtsn7rihRD4ymC2hORI473Af0sfw,2962
|
|
32
34
|
marlenv/wrappers/time_limit.py,sha256=CDIMMJPMyIDHSFxUJaC7nb7Kd86-07NgZeFhrpZm82o,3985
|
|
33
35
|
marlenv/wrappers/video_recorder.py,sha256=d5AFu6qHqby9mOcBsYWYPxAPiK1vtnfMYdZ81AnCekI,2624
|
|
34
|
-
multi_agent_rlenv-3.
|
|
35
|
-
multi_agent_rlenv-3.
|
|
36
|
-
multi_agent_rlenv-3.
|
|
37
|
-
multi_agent_rlenv-3.
|
|
36
|
+
multi_agent_rlenv-3.4.0.dist-info/METADATA,sha256=5TitLGgWA_BrpC_xhFHDufA4QB7v1EZ0FHP3Hdtbf5k,4897
|
|
37
|
+
multi_agent_rlenv-3.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
38
|
+
multi_agent_rlenv-3.4.0.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
|
|
39
|
+
multi_agent_rlenv-3.4.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|