multi-agent-rlenv 3.3.6__py3-none-any.whl → 3.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marlenv/__init__.py CHANGED
@@ -62,7 +62,7 @@ print(env.extras_shape) # (1, )
62
62
  If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
63
63
  """
64
64
 
65
- __version__ = "3.3.6"
65
+ __version__ = "3.3.7"
66
66
 
67
67
  from . import models
68
68
  from . import wrappers
@@ -2,6 +2,7 @@ import sys
2
2
  from dataclasses import dataclass
3
3
  from typing import Literal, Sequence
4
4
  from copy import deepcopy
5
+ from time import time
5
6
 
6
7
  import cv2
7
8
  import numpy as np
@@ -17,8 +18,10 @@ from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
17
18
  @dataclass
18
19
  class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
19
20
  horizon: int
21
+ reward_shaping: bool
20
22
 
21
- def __init__(self, oenv: OvercookedEnv):
23
+ def __init__(self, oenv: OvercookedEnv, reward_shaping: bool = True):
24
+ self.reward_shaping = reward_shaping
22
25
  self._oenv = oenv
23
26
  assert isinstance(oenv.mdp, OvercookedGridworld)
24
27
  self._mdp = oenv.mdp
@@ -86,10 +89,12 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
86
89
  def step(self, actions: Sequence[int] | npt.NDArray[np.int32 | np.int64]) -> Step:
87
90
  actions = [Action.ALL_ACTIONS[a] for a in actions]
88
91
  _, reward, done, info = self._oenv.step(actions, display_phi=True)
92
+ if self.reward_shaping:
93
+ reward += sum(info["shaped_r_by_agent"])
89
94
  return Step(
90
95
  obs=self.get_observation(),
91
96
  state=self.get_state(),
92
- reward=np.array([reward]),
97
+ reward=np.array([reward], dtype=np.float32),
93
98
  done=done,
94
99
  truncated=False,
95
100
  info=info,
@@ -185,6 +190,32 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
185
190
  "you_shall_not_pass",
186
191
  ],
187
192
  horizon: int = 400,
193
+ reward_shaping: bool = True,
188
194
  ):
189
195
  mdp = OvercookedGridworld.from_layout_name(layout)
190
- return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon))
196
+ return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon), reward_shaping=reward_shaping)
197
+
198
+ @staticmethod
199
+ def from_grid(
200
+ grid: Sequence[Sequence[Literal["S", "P", "X", "O", "D", "T", "1", "2", " "] | str]],
201
+ horizon: int = 400,
202
+ reward_shaping: bool = True,
203
+ ):
204
+ """
205
+ Create an Overcooked environment from a grid layout where
206
+ - S is a serving location
207
+ - P is a cooking pot
208
+ - X is a counter
209
+ - O is an onion dispenser
210
+ - D is a dish dispenser
211
+ - T is a tomato dispenser
212
+ - 1 is a player 1 starting location
213
+ - 2 is a player 2 starting location
214
+ - ' ' is a walkable space
215
+ """
216
+ # It is necessary to add an explicit layout name because Overcooked saves some files under this
217
+ # name. By default the name is a concatenation of the grid elements, which may include characters
218
+ # such as white spaces, pipes ('|') and square brackets ('[' and ']') that are invalid in Windows file paths.
219
+ layout_name = str(time())
220
+ mdp = OvercookedGridworld.from_grid(grid, base_layout_params={"layout_name": layout_name})
221
+ return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon), reward_shaping=reward_shaping)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: multi-agent-rlenv
3
- Version: 3.3.6
3
+ Version: 3.3.7
4
4
  Summary: A strongly typed Multi-Agent Reinforcement Learning framework
5
5
  Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
6
6
  Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
@@ -1,4 +1,4 @@
1
- marlenv/__init__.py,sha256=iEVXbl4mQmey7P2uFdeKEYWEmZ8QxNS_f52jNdw4nZs,3741
1
+ marlenv/__init__.py,sha256=u27-QdgKv_1k3uR0oCBN7wcX2jRPmpICbJz1SaZ-f-A,3741
2
2
  marlenv/env_builder.py,sha256=s_lQANqP3iNc8nmcr3CanRVsExnn9qh0ihh4lFr0c4c,5560
3
3
  marlenv/env_pool.py,sha256=R3WIrnQ5Zvff4HR1ecfkDmuO2zl7v1ywQ0K2_nvWFzs,1070
4
4
  marlenv/exceptions.py,sha256=gJUC_2rVAvOfK_ypVFc7Myh-pIfSU3To38VBVS_0rZA,1179
@@ -6,7 +6,7 @@ marlenv/mock_env.py,sha256=qB0fYFIfbopJf7Va8kCeVI5vsOy1-2JdEYe9gdV1Ruw,4761
6
6
  marlenv/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  marlenv/adapters/__init__.py,sha256=rWiqQOqTx3kVL5ZkPo3rkczrlQBBhQbU55zGI26SEeY,929
8
8
  marlenv/adapters/gym_adapter.py,sha256=6CBEjANViTJBTUBmtVyrhJrzjBJxNs_4hmMnXXG2mkU,2906
9
- marlenv/adapters/overcooked_adapter.py,sha256=JZhB50cQGWGjaHWuPwskUKr6YthEptpYC3cD7i9GVvk,6832
9
+ marlenv/adapters/overcooked_adapter.py,sha256=Ehwwha_gh9wsQWBVLvwKYR_P6WUco-W2LoxumVjXSPQ,8289
10
10
  marlenv/adapters/pettingzoo_adapter.py,sha256=4F1au6uctsqRhGfcZOeDRH-8hmrFXnA5xH1Z1Pnek3s,2870
11
11
  marlenv/adapters/pymarl_adapter.py,sha256=x__E90XpFbfSWhnBHtkcD6WYkmKki1LByNbUFoDBUcg,3416
12
12
  marlenv/adapters/smac_adapter.py,sha256=fOfKo1hL4ioKtM5qQGcwtfdkdwUEACjAZqaGmkoQUcU,8373
@@ -31,7 +31,7 @@ marlenv/wrappers/penalty_wrapper.py,sha256=v4_H8OEN2-yujLzRb6P7W7KwmXHtjAFsxcdp3
31
31
  marlenv/wrappers/rlenv_wrapper.py,sha256=C2XekgBIM4x3Wa2Mtsn7rihRD4ymC2hORI473Af0sfw,2962
32
32
  marlenv/wrappers/time_limit.py,sha256=CDIMMJPMyIDHSFxUJaC7nb7Kd86-07NgZeFhrpZm82o,3985
33
33
  marlenv/wrappers/video_recorder.py,sha256=d5AFu6qHqby9mOcBsYWYPxAPiK1vtnfMYdZ81AnCekI,2624
34
- multi_agent_rlenv-3.3.6.dist-info/METADATA,sha256=oHsLxFw-wlgzPyswB6r3QIWZWM_injRIFJuJczyZDTo,4897
35
- multi_agent_rlenv-3.3.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
36
- multi_agent_rlenv-3.3.6.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
37
- multi_agent_rlenv-3.3.6.dist-info/RECORD,,
34
+ multi_agent_rlenv-3.3.7.dist-info/METADATA,sha256=zAUWp4QbbEnxN7EVkJfJAXKYEjAcfUIKaogM43fAjR8,4897
35
+ multi_agent_rlenv-3.3.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
36
+ multi_agent_rlenv-3.3.7.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
37
+ multi_agent_rlenv-3.3.7.dist-info/RECORD,,