multi_agent_rlenv-3.3.3.tar.gz → multi_agent_rlenv-3.3.5.tar.gz

This diff shows the changes between publicly released versions of the package, as they appear in their respective registries. It is provided for informational purposes only.
Files changed (49)
  1. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/PKG-INFO +1 -1
  2. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/__init__.py +1 -1
  3. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/adapters/overcooked_adapter.py +9 -5
  4. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/models/env.py +2 -2
  5. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/tests/test_adapters.py +4 -0
  6. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/tests/test_serialization.py +15 -0
  7. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/.github/workflows/ci.yaml +0 -0
  8. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/.github/workflows/docs.yaml +0 -0
  9. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/.gitignore +0 -0
  10. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/LICENSE +0 -0
  11. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/README.md +0 -0
  12. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/pyproject.toml +0 -0
  13. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/adapters/__init__.py +0 -0
  14. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/adapters/gym_adapter.py +0 -0
  15. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/adapters/pettingzoo_adapter.py +0 -0
  16. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/adapters/pymarl_adapter.py +0 -0
  17. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/adapters/smac_adapter.py +0 -0
  18. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/env_builder.py +0 -0
  19. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/env_pool.py +0 -0
  20. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/exceptions.py +0 -0
  21. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/mock_env.py +0 -0
  22. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/models/__init__.py +0 -0
  23. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/models/episode.py +0 -0
  24. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/models/observation.py +0 -0
  25. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/models/spaces.py +0 -0
  26. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/models/state.py +0 -0
  27. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/models/step.py +0 -0
  28. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/models/transition.py +0 -0
  29. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/py.typed +0 -0
  30. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/__init__.py +0 -0
  31. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/agent_id_wrapper.py +0 -0
  32. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/available_actions_mask.py +0 -0
  33. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/available_actions_wrapper.py +0 -0
  34. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/blind_wrapper.py +0 -0
  35. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/centralised.py +0 -0
  36. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/delayed_rewards.py +0 -0
  37. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/last_action_wrapper.py +0 -0
  38. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/paddings.py +0 -0
  39. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/penalty_wrapper.py +0 -0
  40. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/rlenv_wrapper.py +0 -0
  41. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/time_limit.py +0 -0
  42. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/src/marlenv/wrappers/video_recorder.py +0 -0
  43. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/tests/__init__.py +0 -0
  44. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/tests/test_episode.py +0 -0
  45. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/tests/test_models.py +0 -0
  46. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/tests/test_pool.py +0 -0
  47. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/tests/test_spaces.py +0 -0
  48. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/tests/test_wrappers.py +0 -0
  49. {multi_agent_rlenv-3.3.3 → multi_agent_rlenv-3.3.5}/tests/utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: multi-agent-rlenv
-Version: 3.3.3
+Version: 3.3.5
 Summary: A strongly typed Multi-Agent Reinforcement Learning framework
 Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
 Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
src/marlenv/__init__.py
@@ -62,7 +62,7 @@ print(env.extras_shape) # (1, )
 If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
 """
 
-__version__ = "3.3.3"
+__version__ = "3.3.5"
 
 from . import models
 from . import wrappers
src/marlenv/adapters/overcooked_adapter.py
@@ -22,7 +22,7 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         self._oenv = oenv
         assert isinstance(oenv.mdp, OvercookedGridworld)
         self._mdp = oenv.mdp
-        self.visualizer = StateVisualizer()
+        self._visualizer = StateVisualizer()
         shape = tuple(int(s) for s in self._mdp.get_lossless_state_encoding_shape())
         shape = (shape[2], shape[0], shape[1])
         super().__init__(
@@ -53,19 +53,19 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         return self.state.timestep
 
     def _state_data(self):
-        state = np.array(self._mdp.lossless_state_encoding(self.state))
+        state = np.array(self._mdp.lossless_state_encoding(self.state), dtype=np.float32)
         # Use axes (agents, channels, height, width) instead of (agents, height, width, channels)
         state = np.transpose(state, (0, 3, 1, 2))
         return state
 
     def get_state(self):
-        return State(self._state_data()[0], np.array([self.time_step / self.horizon]))
+        return State(self._state_data()[0], np.array([self.time_step / self.horizon], dtype=np.float32))
 
     def get_observation(self) -> Observation:
         return Observation(
             data=self._state_data(),
             available_actions=self.available_actions(),
-            extras=np.array([[self.time_step / self.horizon]] * self.n_agents),
+            extras=np.array([[self.time_step / self.horizon]] * self.n_agents, dtype=np.float32),
         )
 
     def available_actions(self):
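
Note: NumPy builds float64 arrays from Python floats by default, so without the explicit dtype argument the state encoding and the normalized-time extras came out as float64; the added dtype=np.float32 keyword keeps them consistent with typical float32 network inputs. A minimal sketch of the difference (the values are illustrative):

import numpy as np

# Python floats promote to float64 by default...
extras = np.array([[5 / 60]] * 2)
print(extras.dtype)  # float64

# ...while the explicit dtype added above yields float32 directly.
extras = np.array([[5 / 60]] * 2, dtype=np.float32)
print(extras.dtype)  # float32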
@@ -88,6 +88,10 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
             info=info,
         )
 
+    def reset(self):
+        self._oenv.reset()
+        return self.get_observation(), self.get_state()
+
     def __deepcopy__(self, memo: dict):
         mdp = deepcopy(self._mdp)
         return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=self.horizon))
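
Note: this new override pairs with the base-class change in src/marlenv/models/env.py below, where MARLEnv.reset becomes abstract: the adapter now resets the wrapped OvercookedEnv itself before rebuilding the initial observation and state. A usage sketch, reusing the from_layout constructor and layout name that appear in the tests further down:

from marlenv.adapters import Overcooked

# reset() restarts the wrapped OvercookedEnv and returns a fresh pair.
env = Overcooked.from_layout("scenario1_s", horizon=60)
obs, state = env.reset()
assert obs.data.dtype == "float32"  # see the dtype changes above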
@@ -111,7 +115,7 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         ]:
             rewards_dict[key] = value
 
-        image = self.visualizer.render_state(
+        image = self._visualizer.render_state(
             state=self._oenv.state,
             grid=self._mdp.terrain_mtx,
             hud_data=StateVisualizer.default_hud_data(self._oenv.state, **rewards_dict),
src/marlenv/models/env.py
@@ -127,7 +127,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
 
     def seed(self, seed_value: int):
         """Set the environment seed"""
-        raise NotImplementedError("Method not implemented")
+        return
 
     @abstractmethod
     def get_observation(self) -> Observation:
@@ -158,9 +158,9 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
         """Perform a random step in the environment."""
         return self.step(self.sample_action())
 
+    @abstractmethod
     def reset(self) -> tuple[Observation, State]:
         """Reset the environment and return the initial observation and state."""
-        return self.get_observation(), self.get_state()
 
     def render(self):
         """Render the environment in a window (or in console)"""
tests/test_adapters.py
@@ -162,6 +162,10 @@ def test_overcooked_obs_state():
     height, width = env._mdp.shape
     obs, state = env.reset()
     for i in range(HORIZON):
+        assert obs.data.dtype == np.float32
+        assert state.data.dtype == np.float32
+        assert obs.extras.dtype == np.float32
+        assert state.extras.dtype == np.float32
         assert obs.shape == (26, height, width)
         assert obs.extras_shape == (1,)
         assert state.shape == (26, height, width)
tests/test_serialization.py
@@ -189,3 +189,18 @@ env.reset()""")
     finally:
         os.remove(f.name)
         os.remove(env_file.name)
+
+
+@pytest.mark.skipif(not marlenv.adapters.HAS_OVERCOOKED, reason="Overcooked is not installed")
+def test_serialize_json_overcooked():
+    env = marlenv.adapters.Overcooked.from_layout("scenario1_s", horizon=60)
+    res = orjson.dumps(env, option=orjson.OPT_SERIALIZE_NUMPY)
+    deserialized = orjson.loads(res)
+
+    assert deserialized["n_agents"] == env.n_agents
+    assert tuple(deserialized["observation_shape"]) == env.observation_shape
+    assert tuple(deserialized["state_shape"]) == env.state_shape
+    assert tuple(deserialized["extras_shape"]) == env.extras_shape
+    assert deserialized["n_actions"] == env.n_actions
+    assert deserialized["name"] == env.name
+    assert deserialized["extras_meanings"] == env.extras_meanings