multi-agent-rlenv 3.6.3__tar.gz → 3.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/.github/workflows/ci.yaml +3 -5
  2. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/PKG-INFO +2 -2
  3. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/pyproject.toml +2 -2
  4. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/__init__.py +2 -2
  5. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/gym_adapter.py +3 -3
  6. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/pettingzoo_adapter.py +14 -14
  7. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/smac_adapter.py +10 -7
  8. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/catalog/deepsea.py +1 -1
  9. multi_agent_rlenv-3.7.0/src/marlenv/catalog/two_steps.py +93 -0
  10. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/env_pool.py +3 -3
  11. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/mock_env.py +2 -2
  12. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/spaces.py +7 -7
  13. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/utils/schedule.py +8 -10
  14. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/agent_id_wrapper.py +2 -2
  15. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/blind_wrapper.py +2 -2
  16. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/centralised.py +3 -3
  17. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/delayed_rewards.py +2 -2
  18. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/last_action_wrapper.py +4 -4
  19. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/paddings.py +4 -4
  20. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/potential_shaping.py +2 -2
  21. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/rlenv_wrapper.py +2 -2
  22. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/time_limit.py +2 -2
  23. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/video_recorder.py +2 -2
  24. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_adapters.py +2 -3
  25. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_models.py +7 -7
  26. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_serialization.py +1 -1
  27. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_wrappers.py +18 -18
  28. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/.github/workflows/docs.yaml +0 -0
  29. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/.gitignore +0 -0
  30. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/LICENSE +0 -0
  31. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/README.md +0 -0
  32. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/__init__.py +0 -0
  33. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/pymarl_adapter.py +0 -0
  34. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/catalog/__init__.py +0 -0
  35. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/env_builder.py +0 -0
  36. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/exceptions.py +0 -0
  37. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/__init__.py +0 -0
  38. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/env.py +0 -0
  39. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/episode.py +0 -0
  40. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/observation.py +0 -0
  41. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/state.py +0 -0
  42. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/step.py +0 -0
  43. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/transition.py +0 -0
  44. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/py.typed +0 -0
  45. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/utils/__init__.py +0 -0
  46. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/utils/cached_property_collector.py +0 -0
  47. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/utils/import_placeholders.py +0 -0
  48. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/__init__.py +0 -0
  49. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/action_randomizer.py +0 -0
  50. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/available_actions_mask.py +0 -0
  51. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/available_actions_wrapper.py +0 -0
  52. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/penalty_wrapper.py +0 -0
  53. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/__init__.py +0 -0
  54. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_catalog.py +0 -0
  55. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_deepsea.py +0 -0
  56. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_episode.py +0 -0
  57. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_others.py +0 -0
  58. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_pool.py +0 -0
  59. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_schedules.py +0 -0
  60. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_spaces.py +0 -0
  61. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/utils.py +0 -0
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/.github/workflows/ci.yaml
@@ -27,8 +27,6 @@ jobs:
  - x86_64
  - aarch64
  python-version:
- - '3.10'
- - '3.11'
  - '3.12'
  - '3.13'
  runs-on: ${{ matrix.os }}
@@ -43,7 +41,7 @@ jobs:
  - name: Install uv
  uses: yezz123/setup-uv@v4
  with:
- uv-version: 0.6.4
+ uv-version: 0.9.24
  - name: Install dependencies and run pytest
  run: |
  uv sync --extra overcooked --extra gym --extra pettingzoo --extra torch
@@ -59,11 +57,11 @@ jobs:
  - name: Set up Python
  uses: actions/setup-python@v5
  with:
- python-version: 3.12
+ python-version: 3.13
  - name: Install UV
  uses: yezz123/setup-uv@v4
  with:
- uv-version: 0.6.4
+ uv-version: 0.9.24
  - name: Build wheels
  run: |
  uv venv
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/PKG-INFO
@@ -1,13 +1,13 @@
  Metadata-Version: 2.4
  Name: multi-agent-rlenv
- Version: 3.6.3
+ Version: 3.7.0
  Summary: A strongly typed Multi-Agent Reinforcement Learning framework
  Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
  Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
  License-File: LICENSE
  Classifier: Operating System :: OS Independent
  Classifier: Programming Language :: Python :: 3
- Requires-Python: <4,>=3.10
+ Requires-Python: <4,>=3.12
  Requires-Dist: numpy>=2.0.0
  Requires-Dist: opencv-python>=4.0
  Requires-Dist: typing-extensions>=4.0
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/pyproject.toml
@@ -1,12 +1,12 @@
  [project]
  name = "multi-agent-rlenv"
- version = "3.6.3"
+ version = "3.7.0"
  description = "A strongly typed Multi-Agent Reinforcement Learning framework"
  authors = [
  { "name" = "Yannick Molinghen", "email" = "yannick.molinghen@ulb.be" },
  ]
  readme = "README.md"
- requires-python = ">=3.10, <4"
+ requires-python = ">=3.12, <4"
  urls = { "repository" = "https://github.com/yamoling/multi-agent-rlenv" }
  classifiers = [
  "Programming Language :: Python :: 3",
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/__init__.py
@@ -65,9 +65,9 @@ If you want to create a new environment, you can simply create a class that inhe
  from importlib.metadata import version, PackageNotFoundError

  try:
- __version__ = version("overcooked")
+ __version__ = version("multi-agent-rlenv")
  except PackageNotFoundError:
- __version__ = "0.0.0" # fallback pratique en dev/CI
+ __version__ = "0.0.0" # fallback for CI


  from . import models
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/gym_adapter.py
@@ -44,8 +44,8 @@ class Gym(MARLEnv[Space]):
  raise ValueError("No observation available. Call reset() first.")
  return self._last_obs

- def step(self, actions):
- obs, reward, done, truncated, info = self._gym_env.step(list(actions)[0])
+ def step(self, action):
+ obs, reward, done, truncated, info = self._gym_env.step(list(action)[0])
  self._last_obs = Observation(
  np.array([obs], dtype=np.float32),
  self.available_actions(),
@@ -74,7 +74,7 @@ class Gym(MARLEnv[Space]):
  image = np.array(self._gym_env.render())
  if sys.platform in ("linux", "linux2"):
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
- return image
+ return np.array(image, dtype=np.uint8)

  def seed(self, seed_value: int):
  self._gym_env.reset(seed=seed_value)
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/pettingzoo_adapter.py
@@ -33,39 +33,39 @@ class PettingZoo(MARLEnv[Space]):
  if obs_space.shape is None:
  raise NotImplementedError("Only discrete observation spaces are supported")
  self._pz_env = env
- env.reset()
- super().__init__(n_agents, space, obs_space.shape, self.get_state().shape)
+ self.n_agents = n_agents
+ self.n_actions = space.shape[-1]
+ self.last_observation, state = self.reset()
+ super().__init__(n_agents, space, obs_space.shape, state.shape)
  self.agents = env.possible_agents
- self.last_observation = None

  def get_state(self):
  try:
- return self._pz_env.state()
+ return State(self._pz_env.state())
  except NotImplementedError:
- return np.array([0])
+ assert self.last_observation is not None, "Cannot get the state unless there is a previous observation"
+ return State(self.last_observation.data)

- def step(self, actions: npt.NDArray | Sequence):
- action_dict = dict(zip(self.agents, actions))
+ def step(self, action: npt.NDArray | Sequence):
+ action_dict = dict(zip(self.agents, action))
  obs, reward, term, trunc, info = self._pz_env.step(action_dict)
  obs_data = np.array([v for v in obs.values()])
  reward = np.sum([r for r in reward.values()], keepdims=True)
  self.last_observation = Observation(obs_data, self.available_actions())
- state = State(self.get_state())
+ state = self.get_state()
  return Step(self.last_observation, state, reward, any(term.values()), any(trunc.values()), info)

  def reset(self):
  obs = self._pz_env.reset()[0]
  obs_data = np.array([v for v in obs.values()])
- self.last_observation = Observation(obs_data, self.available_actions(), self.get_state())
- return self.last_observation
+ self.last_observation = Observation(obs_data, self.available_actions())
+ return self.last_observation, self.get_state()

  def get_observation(self):
- if self.last_observation is None:
- raise ValueError("No observation available. Call reset() first.")
  return self.last_observation

  def seed(self, seed_value: int):
  self._pz_env.reset(seed=seed_value)

- def render(self, *_):
- return self._pz_env.render()
+ def render(self):
+ self._pz_env.render()
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/smac_adapter.py
@@ -3,7 +3,7 @@ from typing import overload

  import numpy as np
  import numpy.typing as npt
- from smac.env import StarCraft2Env
+ from smac.env import StarCraft2Env # type:ignore[import]

  from marlenv.models import MARLEnv, Observation, State, Step, MultiDiscreteSpace, DiscreteSpace

@@ -169,17 +169,18 @@ class SMAC(MARLEnv[MultiDiscreteSpace]):

  def reset(self):
  obs, state = self._env.reset()
- obs = Observation(np.array(obs), self.available_actions(), state)
- return obs
+ obs = Observation(np.array(obs), self.available_actions())
+ state = State(state)
+ return obs, state

  def get_observation(self):
- return self._env.get_obs()
+ return Observation(np.array(self._env.get_obs()), self.available_actions())

  def get_state(self):
  return State(self._env.get_state())

- def step(self, actions):
- reward, done, info = self._env.step(actions)
+ def step(self, action):
+ reward, done, info = self._env.step(action)
  obs = Observation(
  self._env.get_obs(), # type: ignore
  self.available_actions(),
@@ -199,7 +200,9 @@ class SMAC(MARLEnv[MultiDiscreteSpace]):
  return np.array(self._env.get_avail_actions()) == 1

  def get_image(self):
- return self._env.render(mode="rgb_array")
+ img = self._env.render(mode="rgb_array")
+ assert img is not None
+ return img

  def seed(self, seed_value: int):
  self._env = StarCraft2Env(map_name=self._env.map_name, seed=seed_value)
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/catalog/deepsea.py
@@ -45,7 +45,7 @@ class DeepSea(MARLEnv[MultiDiscreteSpace]):
  self._col = 0
  return self.get_observation(), self.get_state()

- def step(self, action: Sequence[int]):
+ def step(self, action: Sequence[int] | np.ndarray):
  self._row += 1
  if action[0] == LEFT:
  self._col -= 1
multi_agent_rlenv-3.7.0/src/marlenv/catalog/two_steps.py (new file)
@@ -0,0 +1,93 @@
+ from enum import IntEnum
+ import cv2
+ import marlenv
+ import numpy as np
+ import numpy.typing as npt
+ from typing import Sequence
+ from marlenv import Observation, State, DiscreteSpace, Step
+
+ PAYOFF_INITIAL = [[0, 0], [0, 0]]
+ PAYOFF_2A = [[7, 7], [7, 7]]
+ PAYOFF_2B = [[0, 1], [1, 8]]
+
+
+ class TwoStepsState(IntEnum):
+ INITIAL = 0
+ STATE_2A = 1
+ STATE_2B = 2
+ END = 3
+
+ def one_hot(self):
+ res = np.zeros((4,), dtype=np.float32)
+ res[self.value] = 1
+ return res
+
+ @staticmethod
+ def from_one_hot(x: np.ndarray):
+ for s in TwoStepsState:
+ if x[s.value] == 1:
+ return s
+ raise ValueError()
+
+
+ class TwoStepsGame(marlenv.MARLEnv):
+ """
+ Two-steps game used in QMix paper (https://arxiv.org/pdf/1803.11485.pdf, section 5)
+ to demonstrate its superior representationability compared to VDN.
+ """
+
+ def __init__(self):
+ self.state = TwoStepsState.INITIAL
+ self._identity = np.identity(2, dtype=np.float32)
+ super().__init__(
+ 2,
+ DiscreteSpace(2).repeat(2),
+ observation_shape=(self.state.one_hot().shape[0] + 2,),
+ state_shape=self.state.one_hot().shape,
+ )
+
+ def reset(self):
+ self.state = TwoStepsState.INITIAL
+ return self.observation(), self.get_state()
+
+ def step(self, action: npt.NDArray[np.int32] | Sequence):
+ match self.state:
+ case TwoStepsState.INITIAL:
+ # In the initial step, only agent 0's actions have an influence on the state
+ payoffs = PAYOFF_INITIAL
+ if action[0] == 0:
+ self.state = TwoStepsState.STATE_2A
+ elif action[0] == 1:
+ self.state = TwoStepsState.STATE_2B
+ else:
+ raise ValueError(f"Invalid action: {action[0]}")
+ case TwoStepsState.STATE_2A:
+ payoffs = PAYOFF_2A
+ self.state = TwoStepsState.END
+ case TwoStepsState.STATE_2B:
+ payoffs = PAYOFF_2B
+ self.state = TwoStepsState.END
+ case TwoStepsState.END:
+ raise ValueError("Episode is already over")
+ reward = payoffs[action[0]][action[1]]
+ done = self.state == TwoStepsState.END
+ return Step(self.observation(), self.get_state(), reward, done, False)
+
+ def get_state(self):
+ return State(self.state.one_hot())
+
+ def observation(self):
+ obs_data = np.array([self.state.one_hot(), self.state.one_hot()])
+ extras = self._identity
+ return Observation(obs_data, self.available_actions(), extras)
+
+ def render(self):
+ print(self.state)
+
+ def get_image(self):
+ state = self.state.one_hot()
+ img = cv2.cvtColor(state, cv2.COLOR_GRAY2BGR)
+ return np.array(img, dtype=np.uint8)
+
+ def set_state(self, state: State):
+ self.state = TwoStepsState.from_one_hot(state.data)
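For orientation, here is a minimal usage sketch of the TwoStepsGame environment added above. It is written against the module path shown in this diff (src/marlenv/catalog/two_steps.py); whether the class is also re-exported from the catalog package is not visible here, so the import path is an assumption.

from marlenv.catalog.two_steps import TwoStepsGame

env = TwoStepsGame()
obs, state = env.reset()        # reset() returns an (Observation, State) pair, matching the new reset contract
step = env.step([0, 0])         # agent 0 picks action 0: move to state 2A, zero reward
step = env.step([1, 1])         # every joint action in state 2A pays 7 and ends the episode
print(step.reward, step.done)   # expected: a reward of 7 and done=True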
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/env_pool.py
@@ -20,10 +20,10 @@ class EnvPool(RLEnvWrapper[ActionSpaceType]):
  assert env.has_same_inouts(self.envs[0]), "All environments must have the same inputs and outputs"
  super().__init__(self.envs[0])

- def seed(self, seed: int):
- random.seed(seed)
+ def seed(self, seed_value: int):
+ random.seed(seed_value)
  for env in self.envs:
- env.seed(seed)
+ env.seed(seed_value)

  def reset(self):
  self.wrapped = random.choice(self.envs)
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/mock_env.py
@@ -73,9 +73,9 @@ class DiscreteMockEnv(MARLEnv[MultiDiscreteSpace]):
  def render(self, mode: str = "human"):
  return

- def step(self, actions):
+ def step(self, action):
  self.t += 1
- self.actions_history.append(actions)
+ self.actions_history.append(action)
  return Step(
  self.get_observation(),
  self.get_state(),
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/spaces.py
@@ -8,7 +8,7 @@ import numpy.typing as npt


  @dataclass
- class Space(ABC):
+ class Space[T](ABC):
  shape: tuple[int, ...]
  size: int
  labels: list[str]
@@ -21,7 +21,7 @@ class Space(ABC):
  self.labels = labels

  @abstractmethod
- def sample(self, mask: Optional[npt.NDArray[np.bool_]] = None) -> npt.NDArray[np.float32]:
+ def sample(self, mask: npt.NDArray[np.bool] | None = None) -> T:
  """Sample a value from the space."""

  def __eq__(self, value: object) -> bool:
@@ -44,7 +44,7 @@ class Space(ABC):


  @dataclass
- class DiscreteSpace(Space):
+ class DiscreteSpace(Space[int]):
  size: int
  """Number of categories"""

@@ -53,7 +53,7 @@ class DiscreteSpace(Space):
  self.size = size
  self.space = np.arange(size)

- def sample(self, mask: Optional[npt.NDArray[np.bool]] = None):
+ def sample(self, mask: npt.NDArray[np.bool] | None = None):
  space = self.space.copy()
  if mask is not None:
  space = space[mask]
@@ -87,7 +87,7 @@ class DiscreteSpace(Space):


  @dataclass
- class MultiDiscreteSpace(Space):
+ class MultiDiscreteSpace(Space[npt.NDArray[np.int32]]):
  n_dims: int
  spaces: tuple[DiscreteSpace, ...]

@@ -123,7 +123,7 @@ class MultiDiscreteSpace(Space):


  @dataclass
- class ContinuousSpace(Space):
+ class ContinuousSpace(Space[npt.NDArray[np.float32]]):
  """A continuous space (box) in R^n."""

  low: npt.NDArray[np.float32]
@@ -192,7 +192,7 @@ class ContinuousSpace(Space):
  action = np.array(action)
  return np.clip(action, self.low, self.high)

- def sample(self) -> npt.NDArray[np.float32]:
+ def sample(self, *args, **kwargs):
  r = np.random.random(self.shape) * (self.high - self.low) + self.low
  return r.astype(np.float32)
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/utils/schedule.py
@@ -145,17 +145,15 @@ class Schedule:
  @staticmethod
  def from_json(data: dict[str, Any]):
  """Create a Schedule from a JSON-like dictionary."""
- classname = data.get("name")
- if classname == "LinearSchedule":
- return LinearSchedule(data["start_value"], data["end_value"], data["n_steps"])
- elif classname == "ExpSchedule":
- return ExpSchedule(data["start_value"], data["end_value"], data["n_steps"])
- elif classname == "ConstantSchedule":
- return ConstantSchedule(data["value"])
- elif classname == "ArbitrarySchedule":
+ candidates = [LinearSchedule, ExpSchedule, ConstantSchedule]
+ data = data.copy()
+ classname = data.pop("name")
+ for cls in candidates:
+ if cls.__name__ == classname:
+ return cls(**data)
+ if classname == "ArbitrarySchedule":
  raise NotImplementedError("ArbitrarySchedule cannot be deserialized from JSON")
- else:
- raise ValueError(f"Unknown schedule type: {classname}")
+ raise ValueError(f"Unknown schedule type: {classname}")


  @dataclass(eq=False)
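As a rough sketch of how the refactored Schedule.from_json above dispatches: the "name" entry selects the class and the remaining keys are forwarded to its constructor as keyword arguments. The field names below are taken from the previous explicit branches and are assumed to match the constructor parameter names.

data = {"name": "LinearSchedule", "start_value": 1.0, "end_value": 0.05, "n_steps": 10_000}
schedule = Schedule.from_json(data)   # roughly equivalent to LinearSchedule(start_value=1.0, end_value=0.05, n_steps=10_000)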
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/agent_id_wrapper.py
@@ -18,8 +18,8 @@ class AgentId(RLEnvWrapper[AS]):
  super().__init__(env, extra_shape=(env.n_agents + env.extras_shape[0],), extra_meanings=meanings)
  self._identity = np.identity(env.n_agents, dtype=np.float32)

- def step(self, actions):
- step = super().step(actions)
+ def step(self, action):
+ step = super().step(action)
  step.obs.add_extra(self._identity)
  return step
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/blind_wrapper.py
@@ -18,8 +18,8 @@ class Blind(RLEnvWrapper[AS]):
  super().__init__(env)
  self.p = float(p)

- def step(self, actions):
- step = super().step(actions)
+ def step(self, action):
+ step = super().step(action)
  if random.random() < self.p:
  step.obs.data = np.zeros_like(step.obs.data)
  return step
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/centralised.py
@@ -42,9 +42,9 @@ class Centralized(RLEnvWrapper[MultiDiscreteSpace]):
  action_names = [str(a) for a in product(*agent_actions)]
  return DiscreteSpace(env.n_actions**env.n_agents, action_names).repeat(1)

- def step(self, actions: npt.NDArray | Sequence):
- action = actions[0]
- individual_actions = self._individual_actions(action)
+ def step(self, action: npt.NDArray | Sequence):
+ action1 = action[0]
+ individual_actions = self._individual_actions(action1)
  individual_actions = np.array(individual_actions)
  step = self.wrapped.step(individual_actions) # type: ignore
  step.obs = self._joint_observation(step.obs)
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/delayed_rewards.py
@@ -27,8 +27,8 @@ class DelayedReward(RLEnvWrapper[AS]):
  self.reward_queue.append(np.zeros(self.reward_space.shape, dtype=np.float32))
  return super().reset()

- def step(self, actions):
- step = super().step(actions)
+ def step(self, action):
+ step = super().step(action)
  self.reward_queue.append(step.reward)
  # If the step is terminal, we sum all the remaining rewards
  if step.is_terminal:
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/last_action_wrapper.py
@@ -33,13 +33,13 @@ class LastAction(RLEnvWrapper[AS]):
  state.add_extra(self.last_one_hot_actions.flatten())
  return obs, state

- def step(self, actions):
- step = super().step(actions)
+ def step(self, action):
+ step = super().step(action)
  match self.wrapped.action_space:
  case ContinuousSpace():
- self.last_actions = actions
+ self.last_actions = action
  case DiscreteSpace() | MultiDiscreteSpace():
- self.last_one_hot_actions = self.compute_one_hot_actions(actions)
+ self.last_one_hot_actions = self.compute_one_hot_actions(action)
  case other:
  raise NotImplementedError(f"Action space {other} not supported")
  step.obs.add_extra(self.last_one_hot_actions)
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/paddings.py
@@ -24,8 +24,8 @@ class PadExtras(RLEnvWrapper[AS]):
  )
  self.n = n_added

- def step(self, actions):
- step = super().step(actions)
+ def step(self, action):
+ step = super().step(action)
  step.obs = self._add_extras(step.obs)
  return step

@@ -48,8 +48,8 @@ class PadObservations(RLEnvWrapper[AS]):
  super().__init__(env, observation_shape=(env.observation_shape[0] + n_added,))
  self.n = n_added

- def step(self, actions):
- step = super().step(actions)
+ def step(self, action):
+ step = super().step(action)
  step.obs = self._add_obs(step.obs)
  return step
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/potential_shaping.py
@@ -39,9 +39,9 @@ class PotentialShaping(RLEnvWrapper[A], ABC):
  self._current_potential = self.compute_potential()
  return self.add_extras(obs), state

- def step(self, actions):
+ def step(self, action):
  prev_potential = self._current_potential
- step = super().step(actions)
+ step = super().step(action)

  self._current_potential = self.compute_potential()
  shaped_reward = self.gamma * self._current_potential - prev_potential
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/rlenv_wrapper.py
@@ -62,8 +62,8 @@ class RLEnvWrapper(MARLEnv[AS]):
  def agent_state_size(self):
  return self.wrapped.agent_state_size

- def step(self, actions: np.ndarray | Sequence):
- return self.wrapped.step(actions)
+ def step(self, action: np.ndarray | Sequence):
+ return self.wrapped.step(action)

  def reset(self):
  return self.wrapped.reset()
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/time_limit.py
@@ -64,9 +64,9 @@ class TimeLimit(RLEnvWrapper[AS]):
  self.add_time_extra(obs, state)
  return obs, state

- def step(self, actions):
+ def step(self, action):
  self._current_step += 1
- step = super().step(actions)
+ step = super().step(action)
  if self.add_extra:
  self.add_time_extra(step.obs, step.state)
  # If we reach the time limit
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/video_recorder.py
@@ -44,10 +44,10 @@ class VideoRecorder(RLEnvWrapper[AS]):
  case other:
  raise ValueError(f"Unsupported file video encoding: {other}")

- def step(self, actions):
+ def step(self, action):
  if self._recorder is None:
  raise RuntimeError("VideoRecorder not initialized")
- step = super().step(actions)
+ step = super().step(action)
  img = self.get_image()
  self._recorder.write(img)
  if step.is_terminal:
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_adapters.py
@@ -98,7 +98,7 @@ def _check_env_3m(env):
  from marlenv.adapters import SMAC

  assert isinstance(env, SMAC)
- obs = env.reset()
+ obs, state = env.reset()
  assert isinstance(obs, Observation)
  assert env.n_agents == 3
  assert isinstance(env.action_space, MultiDiscreteSpace)
@@ -114,8 +114,7 @@ def _check_env_3m(env):

  @pytest.mark.skipif(skip_smac, reason="SMAC is not installed")
  def test_smac_from_class():
- from smac.env import StarCraft2Env
-
+ from smac.env import StarCraft2Env # type: ignore[import]
  from marlenv.adapters import SMAC

  env = SMAC(StarCraft2Env("3m"))
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_models.py
@@ -380,8 +380,8 @@ def test_env_replay():
  available[(agent + self._seed) % self.n_actions] = True
  return availables

- def step(self, actions):
- return super().step(actions)
+ def step(self, action):
+ return super().step(action)

  def seed(self, seed_value: int):
  np.random.seed(seed_value)
@@ -409,16 +409,16 @@ def test_wrong_extras_meanings_length():
  super().__init__(4, DiscreteSpace(5), (10,), (10,), extras_shape=(5,), extras_meanings=["a", "b", "c"])

  def get_observation(self):
- pass
+ raise NotImplementedError()

  def get_state(self):
- pass
+ raise NotImplementedError()

- def step(self, actions):
- pass
+ def step(self, action):
+ raise NotImplementedError()

  def reset(self):
- pass
+ raise NotImplementedError()

  try:
  TestClass()
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_serialization.py
@@ -241,7 +241,7 @@ def test_serialize_schedule():
  try:
  pickle.dumps(s)
  assert False, "Should not be able to pickle arbitrary schedules because of the callable lambda"
- except AttributeError:
+ except (pickle.PicklingError, AttributeError):
  pass

  s = Schedule.arbitrary(C())
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_wrappers.py
@@ -1,6 +1,6 @@
  import numpy as np
  from marlenv import Builder, DiscreteMOMockEnv, DiscreteMockEnv, MARLEnv
- from marlenv.wrappers import Centralized, AvailableActionsMask, TimeLimit, LastAction, DelayedReward, ActionRandomizer
+ from marlenv.wrappers import Centralized, AvailableActionsMask, TimeLimit, LastAction, DelayedReward
  import marlenv


@@ -55,13 +55,12 @@ def test_time_limit_wrapper():
  env = Builder(DiscreteMockEnv(1)).time_limit(MAX_T).build()
  assert env.extras_shape == (1,)
  assert env.state_extra_shape == (1,)
- done = False
- t = 0
- while not done:
- step = env.step(np.array([0]))
+ t = 1
+ step = env.step(np.array([0]))
+ while not step.done:
  assert step.obs.extras.shape == (env.n_agents, 1)
  assert step.state.extras_shape == (1,)
- done = step.done
+ step = env.step(np.array([0]))
  t += 1
  assert t == MAX_T
  assert step.truncated
@@ -73,12 +72,15 @@ def test_truncated_and_done():
  env = marlenv.wrappers.TimeLimit(DiscreteMockEnv(2, end_game=END_GAME), END_GAME)
  obs, state = env.reset()
  episode = marlenv.Episode.new(obs, state)
+ action = env.action_space.sample()
+ step = env.step(action)
  while not episode.is_finished:
- action = env.action_space.sample()
- step = env.step(action)
  episode.add(marlenv.Transition.from_step(obs, state, action, step))
  obs = step.obs
  state = step.state
+ action = env.action_space.sample()
+ step = env.step(action)
+
  assert step.done
  assert not step.truncated, (
  "The episode is done, so it does not have to be truncated even though the time limit is reached at the same time."
@@ -97,11 +99,10 @@ def test_time_limit_wrapper_with_extra():
  assert env.extras_shape == (1,)
  obs, _ = env.reset()
  assert obs.extras.shape == (5, 1)
- stop = False
- t = 0
- while not stop:
+ t = 1
+ step = env.step(np.array([0]))
+ while not step.is_terminal:
  step = env.step(np.array([0]))
- stop = step.done or step.truncated
  t += 1
  assert t == MAX_T
  assert np.all(step.obs.extras == 1.0)
@@ -129,11 +130,10 @@ def test_time_limit_wrapper_with_truncation_penalty():
  assert env.extras_shape == (1,)
  obs, _ = env.reset()
  assert obs.extras.shape == (5, 1)
- stop = False
- t = 0
- while not stop:
+ t = 1
+ step = env.step(np.array([0]))
+ while not step.is_terminal:
  step = env.step(np.array([0]))
- stop = step.done or step.truncated
  t += 1
  assert t == MAX_T
  assert np.all(step.obs.extras[:] == 1)
@@ -374,9 +374,9 @@ def test_potential_shaping():
  def compute_potential(self) -> float:
  return self.phi

- def step(self, actions):
+ def step(self, action):
  self.phi = max(0, self.phi - 1)
- return super().step(actions)
+ return super().step(action)

  EP_LENGTH = 20
  env = PS(DiscreteMockEnv(reward_step=0, end_game=EP_LENGTH))