multi-agent-rlenv 3.5.1__tar.gz → 3.5.4__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (53)
  1. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/.github/workflows/ci.yaml +1 -1
  2. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/PKG-INFO +4 -1
  3. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/pyproject.toml +2 -0
  4. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/__init__.py +1 -1
  5. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/models/env.py +1 -1
  6. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/models/episode.py +10 -55
  7. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/models/observation.py +10 -0
  8. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/models/state.py +8 -0
  9. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/potential_shaping.py +11 -6
  10. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/tests/test_models.py +34 -1
  11. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/.github/workflows/docs.yaml +0 -0
  12. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/.gitignore +0 -0
  13. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/LICENSE +0 -0
  14. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/README.md +0 -0
  15. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/adapters/__init__.py +0 -0
  16. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/adapters/gym_adapter.py +0 -0
  17. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/adapters/overcooked_adapter.py +0 -0
  18. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/adapters/pettingzoo_adapter.py +0 -0
  19. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/adapters/pymarl_adapter.py +0 -0
  20. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/adapters/smac_adapter.py +0 -0
  21. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/env_builder.py +0 -0
  22. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/env_pool.py +0 -0
  23. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/exceptions.py +0 -0
  24. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/mock_env.py +0 -0
  25. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/models/__init__.py +0 -0
  26. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/models/spaces.py +0 -0
  27. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/models/step.py +0 -0
  28. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/models/transition.py +0 -0
  29. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/py.typed +0 -0
  30. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/utils/__init__.py +0 -0
  31. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/utils/schedule.py +0 -0
  32. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/__init__.py +0 -0
  33. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/agent_id_wrapper.py +0 -0
  34. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/available_actions_mask.py +0 -0
  35. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/available_actions_wrapper.py +0 -0
  36. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/blind_wrapper.py +0 -0
  37. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/centralised.py +0 -0
  38. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/delayed_rewards.py +0 -0
  39. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/last_action_wrapper.py +0 -0
  40. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/paddings.py +0 -0
  41. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/penalty_wrapper.py +0 -0
  42. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/rlenv_wrapper.py +0 -0
  43. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/time_limit.py +0 -0
  44. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/src/marlenv/wrappers/video_recorder.py +0 -0
  45. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/tests/__init__.py +0 -0
  46. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/tests/test_adapters.py +0 -0
  47. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/tests/test_episode.py +0 -0
  48. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/tests/test_pool.py +0 -0
  49. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/tests/test_schedules.py +0 -0
  50. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/tests/test_serialization.py +0 -0
  51. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/tests/test_spaces.py +0 -0
  52. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/tests/test_wrappers.py +0 -0
  53. {multi_agent_rlenv-3.5.1 → multi_agent_rlenv-3.5.4}/tests/utils.py +0 -0
--- multi_agent_rlenv-3.5.1/.github/workflows/ci.yaml
+++ multi_agent_rlenv-3.5.4/.github/workflows/ci.yaml
@@ -46,7 +46,7 @@ jobs:
           uv-version: 0.6.4
       - name: Install dependencies and run pytest
         run: |
-          uv sync --extra overcooked --extra gym --extra pettingzoo
+          uv sync --extra overcooked --extra gym --extra pettingzoo --extra torch
           uv run pytest
 
   build:
--- multi_agent_rlenv-3.5.1/PKG-INFO
+++ multi_agent_rlenv-3.5.4/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: multi-agent-rlenv
-Version: 3.5.1
+Version: 3.5.4
 Summary: A strongly typed Multi-Agent Reinforcement Learning framework
 Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
 Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
@@ -19,6 +19,7 @@ Requires-Dist: pymunk>=6.0; extra == 'all'
 Requires-Dist: pysc2; extra == 'all'
 Requires-Dist: scipy>=1.10; extra == 'all'
 Requires-Dist: smac; extra == 'all'
+Requires-Dist: torch>=2.0; extra == 'all'
 Provides-Extra: gym
 Requires-Dist: gymnasium>=0.29.1; extra == 'gym'
 Provides-Extra: overcooked
@@ -31,6 +32,8 @@ Requires-Dist: scipy>=1.10; extra == 'pettingzoo'
 Provides-Extra: smac
 Requires-Dist: pysc2; extra == 'smac'
 Requires-Dist: smac; extra == 'smac'
+Provides-Extra: torch
+Requires-Dist: torch>=2.0; extra == 'torch'
 Description-Content-Type: text/markdown
 
 # `marlenv` - A unified framework for muti-agent reinforcement learning
--- multi_agent_rlenv-3.5.1/pyproject.toml
+++ multi_agent_rlenv-3.5.4/pyproject.toml
@@ -20,6 +20,7 @@ gym = ["gymnasium>=0.29.1"]
 smac = ["smac", "pysc2"]
 pettingzoo = ["pettingzoo>=1.20", "pymunk>=6.0", "scipy>=1.10"]
 overcooked = ["overcooked-ai>=1.1.0", "scipy>=1.10"]
+torch = ["torch>=2.0"]
 all = [
     "gymnasium>0.29.1",
     "pettingzoo>=1.20",
@@ -28,6 +29,7 @@ all = [
     "pysc2",
     "pymunk>=6.0",
     "scipy>=1.10",
+    "torch>=2.0",
 ]
 
 [build-system]
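
The torch dependency stays optional: nothing imports it at module load time. A minimal sketch (my own illustration, not code from the package) of guarding torch-dependent code the same way the updated tests below do:

```python
# Sketch only: mirrors the HAS_PYTORCH guard added to tests/test_models.py in this release.
from importlib.util import find_spec

HAS_PYTORCH = find_spec("torch") is not None

if HAS_PYTORCH:
    import torch  # available when the package is installed with the "torch" extra

    print("torch", torch.__version__, "is available")
else:
    print("torch extra not installed; the new as_tensors() helpers cannot be used")
```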
--- multi_agent_rlenv-3.5.1/src/marlenv/__init__.py
+++ multi_agent_rlenv-3.5.4/src/marlenv/__init__.py
@@ -62,7 +62,7 @@ print(env.extras_shape) # (1, )
 If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
 """
 
-__version__ = "3.5.1"
+__version__ = "3.5.4"
 
 from . import models
 from .models import (
--- multi_agent_rlenv-3.5.1/src/marlenv/models/env.py
+++ multi_agent_rlenv-3.5.4/src/marlenv/models/env.py
@@ -199,7 +199,7 @@ class MARLEnv(ABC, Generic[ActionSpaceType]):
             episode.add(step, action)
         return episode
 
-    def has_same_inouts(self, other) -> bool:
+    def has_same_inouts(self, other: "MARLEnv[ActionSpaceType]") -> bool:
         """Alias for `have_same_inouts(self, other)`."""
         if not isinstance(other, MARLEnv):
             return False
--- multi_agent_rlenv-3.5.1/src/marlenv/models/episode.py
+++ multi_agent_rlenv-3.5.4/src/marlenv/models/episode.py
@@ -22,10 +22,10 @@ class Episode:
     all_extras: list[npt.NDArray[np.float32]]
     actions: list[npt.NDArray]
     rewards: list[npt.NDArray[np.float32]]
-    all_available_actions: list[npt.NDArray[np.bool_]]
+    all_available_actions: list[npt.NDArray[np.bool]]
     all_states: list[npt.NDArray[np.float32]]
     all_states_extras: list[npt.NDArray[np.float32]]
-    metrics: dict[str, float]
+    metrics: dict[str, Any]
     episode_len: int
     other: dict[str, list[Any]]
     is_done: bool = False
@@ -33,7 +33,7 @@ class Episode:
     """Whether the episode did reach a terminal state (different from truncated)"""
 
     @staticmethod
-    def new(obs: Observation, state: State, metrics: Optional[dict[str, float]] = None) -> "Episode":
+    def new(obs: Observation, state: State, metrics: Optional[dict[str, Any]] = None) -> "Episode":
         if metrics is None:
             metrics = {}
         return Episode(
@@ -66,13 +66,13 @@ class Episode:
         if target_len < self.episode_len:
             raise ValueError(f"Cannot pad episode to a smaller size: {target_len} < {self.episode_len}")
         padding_size = target_len - self.episode_len
-        obs = self.all_observations + [self.all_observations[0]] * padding_size
-        extras = self.all_extras + [self.all_extras[0]] * padding_size
-        actions = self.actions + [self.actions[0]] * padding_size
-        rewards = self.rewards + [self.rewards[0]] * padding_size
+        obs = self.all_observations + [np.zeros_like(self.all_observations[0])] * padding_size
+        extras = self.all_extras + [np.zeros_like(self.all_extras[0])] * padding_size
+        actions = self.actions + [np.zeros_like(self.actions[0])] * padding_size
+        rewards = self.rewards + [np.zeros_like(self.rewards[0])] * padding_size
         availables = self.all_available_actions + [self.all_available_actions[0]] * padding_size
-        states = self.all_states + [self.all_states[0]] * padding_size
-        states_extras = self.all_states_extras + [self.all_states_extras[0]] * padding_size
+        states = self.all_states + [np.zeros_like(self.all_states[0])] * padding_size
+        states_extras = self.all_states_extras + [np.zeros_like(self.all_states_extras[0])] * padding_size
         other = {key: value + [value[0]] * padding_size for key, value in self.other.items()}
         return Episode(
             all_observations=obs,
@@ -363,51 +363,6 @@ class Episode:
         for i, s in enumerate(scores):
             self.metrics[f"score-{i}"] = float(s)
 
-    # def add_data(
-    #     self,
-    #     new_obs: Observation,
-    #     new_state: State,
-    #     action: A,
-    #     reward: np.ndarray,
-    #     done: bool,
-    #     truncated: bool,
-    #     info: dict[str, Any],
-    #     **kwargs,
-    # ):
-    #     """Add a new transition to the episode"""
-    #     self.episode_len += 1
-    #     self.all_observations.append(new_obs.data)
-    #     self.all_extras.append(new_obs.extras)
-    #     self.all_available_actions.append(new_obs.available_actions)
-    #     self.all_states.append(new_state.data)
-    #     self.all_states_extras.append(new_state.extras)
-    #     match action:
-    #         case np.ndarray() as action:
-    #             self.actions.append(action)
-    #         case other:
-    #             self.actions.append(np.array(other))
-    #     self.rewards.append(reward)
-    #     for key, value in kwargs.items():
-    #         current = self.other.get(key, [])
-    #         current.append(value)
-    #         self.other[key] = current
-
-    #     if done:
-    #         # Only set the truncated flag if the episode is not done (both could happen with a time limit)
-    #         self.is_truncated = truncated
-    #     self.is_done = done
-    #     # Add metrics that can be plotted
-    #     for key, value in info.items():
-    #         if isinstance(value, bool):
-    #             value = int(value)
-    #         self.metrics[key] = value
-    #     self.metrics["episode_len"] = self.episode_len
-
-    #     rewards = np.array(self.rewards)
-    #     scores = np.sum(rewards, axis=0)
-    #     for i, s in enumerate(scores):
-    #         self.metrics[f"score-{i}"] = float(s)
-
-    def add_metrics(self, metrics: dict[str, float]):
+    def add_metrics(self, metrics: dict[str, Any]):
         """Add metrics to the episode"""
         self.metrics.update(metrics)
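
The padding change above replaces repeat-the-first-element padding with zero padding. A self-contained sketch of the difference on plain NumPy arrays (my own illustration; it does not use the Episode class itself):

```python
import numpy as np

# Two recorded per-agent reward vectors, padded out to a target length of 5.
rewards = [np.array([1.0, 0.5], dtype=np.float32), np.array([0.0, 2.0], dtype=np.float32)]
padding_size = 3

# 3.5.1 behaviour: repeat the first entry, so padded steps look like real rewards.
old_padded = rewards + [rewards[0]] * padding_size

# 3.5.4 behaviour: zero arrays with the same shape and dtype, so padded steps are neutral.
new_padded = rewards + [np.zeros_like(rewards[0])] * padding_size

assert len(new_padded) == 5
assert old_padded[-1].sum() == 1.5  # padding carried spurious reward
assert new_padded[-1].sum() == 0.0  # padding contributes nothing
```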
--- multi_agent_rlenv-3.5.1/src/marlenv/models/observation.py
+++ multi_agent_rlenv-3.5.4/src/marlenv/models/observation.py
@@ -87,3 +87,13 @@ class Observation:
         if not np.array_equal(self.data, other.data):
             return False
         return np.array_equal(self.extras, other.extras) and np.array_equal(self.available_actions, other.available_actions)
+
+    def as_tensors(self, device=None):
+        """
+        Convert the observation to a tuple of tensors of shape (1, n_agents, <dim>).
+        """
+        import torch
+
+        data = torch.from_numpy(self.data).unsqueeze(0).to(device, non_blocking=True)
+        extras = torch.from_numpy(self.extras).unsqueeze(0).to(device, non_blocking=True)
+        return data, extras
--- multi_agent_rlenv-3.5.1/src/marlenv/models/state.py
+++ multi_agent_rlenv-3.5.4/src/marlenv/models/state.py
@@ -52,3 +52,11 @@ class State(Generic[StateType]):
         if not np.array_equal(self.extras, value.extras):
             return False
         return True
+
+    def as_tensors(self, device=None):
+        """Convert the state to a tuple of tensors of shape (1, <dim>)."""
+        import torch
+
+        data = torch.from_numpy(self.data).unsqueeze(0).to(device, non_blocking=True)
+        extras = torch.from_numpy(self.extras).unsqueeze(0).to(device, non_blocking=True)
+        return data, extras
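
A short usage sketch of the two new as_tensors() helpers, modelled on the tests added further below; it assumes torch is installed (the new "torch" extra) and that DiscreteMockEnv behaves as in the package's own test suite:

```python
import torch
from marlenv import DiscreteMockEnv

env = DiscreteMockEnv(4)
obs, state = env.reset()

# Observation tensors get a leading batch dimension of 1: (1, n_agents, *obs_shape).
obs_data, obs_extras = obs.as_tensors(device="cpu")
assert isinstance(obs_data, torch.Tensor) and obs_data.dtype == torch.float32
assert obs_data.shape == (1, env.n_agents, *env.observation_shape)

# State tensors are batched the same way but without the agent dimension: (1, *state_shape).
state_data, state_extras = state.as_tensors()
assert state_data.shape == (1, *env.state_shape)
```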
--- multi_agent_rlenv-3.5.1/src/marlenv/wrappers/potential_shaping.py
+++ multi_agent_rlenv-3.5.4/src/marlenv/wrappers/potential_shaping.py
@@ -2,10 +2,15 @@ from abc import abstractmethod, ABC
 from .rlenv_wrapper import RLEnvWrapper
 from marlenv import Space, MARLEnv, Observation
 from typing import TypeVar, Optional
+import numpy as np
+import numpy.typing as npt
+
+from dataclasses import dataclass
 
 A = TypeVar("A", bound=Space)
 
 
+@dataclass
 class PotentialShaping(RLEnvWrapper[A], ABC):
     """
     Potential shaping for the Laser Learning Environment (LLE).
@@ -23,7 +28,7 @@ class PotentialShaping(RLEnvWrapper[A], ABC):
     ):
         super().__init__(env, extra_shape=extra_shape)
         self.gamma = gamma
-        self.current_potential = self.compute_potential()
+        self._current_potential = self.compute_potential()
 
     def add_extras(self, obs: Observation) -> Observation:
         """Add the extras related to potential shaping. Does nothing by default."""
@@ -31,19 +36,19 @@ class PotentialShaping(RLEnvWrapper[A], ABC):
 
     def reset(self):
         obs, state = super().reset()
-        self.current_potential = self.compute_potential()
+        self._current_potential = self.compute_potential()
         return self.add_extras(obs), state
 
     def step(self, actions):
-        phi_t = self.current_potential
+        prev_potential = self._current_potential
         step = super().step(actions)
 
-        self.current_potential = self.compute_potential()
-        shaped_reward = self.gamma * self.current_potential - phi_t
+        self._current_potential = self.compute_potential()
+        shaped_reward = self.gamma * self._current_potential - prev_potential
         step.obs = self.add_extras(step.obs)
         step.reward += shaped_reward
         return step
 
     @abstractmethod
-    def compute_potential(self) -> float:
+    def compute_potential(self) -> float | npt.NDArray[np.float32]:
         """Compute the potential of the current state of the environment."""
--- multi_agent_rlenv-3.5.1/tests/test_models.py
+++ multi_agent_rlenv-3.5.4/tests/test_models.py
@@ -1,8 +1,11 @@
 from marlenv import Observation, Transition, DiscreteMockEnv, DiscreteMOMockEnv, Builder, State, Episode, MARLEnv, DiscreteSpace
 import numpy as np
-
+import pytest
+from importlib.util import find_spec
 from .utils import generate_episode
 
+HAS_PYTORCH = find_spec("torch") is not None
+
 
 def test_obs_eq():
     obs1 = Observation(
@@ -422,3 +425,33 @@ def test_wrong_extras_meanings_length():
         assert False, "This should raise a ValueError because the length of extras_meanings is different from the actual number of extras"
     except ValueError:
         pass
+
+
+@pytest.mark.skipif(not HAS_PYTORCH, reason="torch is not installed")
+def test_observation_as_tensor():
+    import torch
+
+    env = DiscreteMockEnv(4)
+    obs = env.reset()[0]
+    data, extras = obs.as_tensors()
+    assert isinstance(data, torch.Tensor)
+    assert data.shape == (1, env.n_agents, *env.observation_shape)
+    assert data.dtype == torch.float32
+    assert isinstance(extras, torch.Tensor)
+    assert extras.shape == (1, env.n_agents, *env.extras_shape)
+    assert extras.dtype == torch.float32
+
+
+@pytest.mark.skipif(not HAS_PYTORCH, reason="torch is not installed")
+def test_state_as_tensor():
+    import torch
+
+    env = DiscreteMockEnv(4)
+    state = env.reset()[1]
+    data, extras = state.as_tensors()
+    assert isinstance(data, torch.Tensor)
+    assert data.shape == (1, *env.state_shape)
+    assert data.dtype == torch.float32
+    assert isinstance(extras, torch.Tensor)
+    assert extras.shape == (1, *env.state_extra_shape)
+    assert extras.dtype == torch.float32