multi-agent-rlenv 3.5.2__py3-none-any.whl → 3.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marlenv/__init__.py CHANGED
@@ -62,7 +62,7 @@ print(env.extras_shape) # (1, )
 If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
 """
 
-__version__ = "3.5.2"
+__version__ = "3.5.5"
 
 from . import models
 from .models import (
marlenv/models/episode.py CHANGED
@@ -2,30 +2,32 @@ from dataclasses import dataclass
 from functools import cached_property
 from typing import Any, Callable, Optional, Sequence, overload
 
+import cv2
 import numpy as np
 import numpy.typing as npt
-import cv2
 
+from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException
+from marlenv.utils import CachedPropertyInvalidator
+
+from .env import MARLEnv
 from .observation import Observation
 from .state import State
 from .step import Step
 from .transition import Transition
-from .env import MARLEnv
-from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException
 
 
 @dataclass
-class Episode:
+class Episode(CachedPropertyInvalidator):
     """Episode model made of observations, actions, rewards, ..."""
 
     all_observations: list[npt.NDArray[np.float32]]
     all_extras: list[npt.NDArray[np.float32]]
     actions: list[npt.NDArray]
     rewards: list[npt.NDArray[np.float32]]
-    all_available_actions: list[npt.NDArray[np.bool_]]
+    all_available_actions: list[npt.NDArray[np.bool]]
     all_states: list[npt.NDArray[np.float32]]
     all_states_extras: list[npt.NDArray[np.float32]]
-    metrics: dict[str, float]
+    metrics: dict[str, Any]
     episode_len: int
     other: dict[str, list[Any]]
     is_done: bool = False
@@ -33,7 +35,7 @@ class Episode:
     """Whether the episode did reach a terminal state (different from truncated)"""
 
     @staticmethod
-    def new(obs: Observation, state: State, metrics: Optional[dict[str, float]] = None) -> "Episode":
+    def new(obs: Observation, state: State, metrics: Optional[dict[str, Any]] = None) -> "Episode":
         if metrics is None:
             metrics = {}
         return Episode(
@@ -153,12 +155,12 @@ class Episode:
         """Get the next extra features"""
         return self.all_extras[1:]
 
-    @cached_property
+    @property
     def n_agents(self):
         """The number of agents in the episode"""
         return self.all_extras[0].shape[0]
 
-    @cached_property
+    @property
     def n_actions(self):
         """The number of actions"""
         return len(self.all_available_actions[0][0])
@@ -267,7 +269,7 @@ class Episode:
     def __len__(self):
         return self.episode_len
 
-    @cached_property
+    @property
     def score(self) -> list[float]:
         """The episode score (sum of all rewards across all objectives)"""
         score = []
@@ -363,51 +365,6 @@ class Episode:
         for i, s in enumerate(scores):
             self.metrics[f"score-{i}"] = float(s)
 
-    # def add_data(
-    #     self,
-    #     new_obs: Observation,
-    #     new_state: State,
-    #     action: A,
-    #     reward: np.ndarray,
-    #     done: bool,
-    #     truncated: bool,
-    #     info: dict[str, Any],
-    #     **kwargs,
-    # ):
-    #     """Add a new transition to the episode"""
-    #     self.episode_len += 1
-    #     self.all_observations.append(new_obs.data)
-    #     self.all_extras.append(new_obs.extras)
-    #     self.all_available_actions.append(new_obs.available_actions)
-    #     self.all_states.append(new_state.data)
-    #     self.all_states_extras.append(new_state.extras)
-    #     match action:
-    #         case np.ndarray() as action:
-    #             self.actions.append(action)
-    #         case other:
-    #             self.actions.append(np.array(other))
-    #     self.rewards.append(reward)
-    #     for key, value in kwargs.items():
-    #         current = self.other.get(key, [])
-    #         current.append(value)
-    #         self.other[key] = current
-
-    #     if done:
-    #         # Only set the truncated flag if the episode is not done (both could happen with a time limit)
-    #         self.is_truncated = truncated
-    #     self.is_done = done
-    #     # Add metrics that can be plotted
-    #     for key, value in info.items():
-    #         if isinstance(value, bool):
-    #             value = int(value)
-    #         self.metrics[key] = value
-    #     self.metrics["episode_len"] = self.episode_len
-
-    #     rewards = np.array(self.rewards)
-    #     scores = np.sum(rewards, axis=0)
-    #     for i, s in enumerate(scores):
-    #         self.metrics[f"score-{i}"] = float(s)
-
-    def add_metrics(self, metrics: dict[str, float]):
+    def add_metrics(self, metrics: dict[str, Any]):
         """Add metrics to the episode"""
         self.metrics.update(metrics)
marlenv/models/observation.py CHANGED
@@ -87,3 +87,13 @@ class Observation:
         if not np.array_equal(self.data, other.data):
             return False
         return np.array_equal(self.extras, other.extras) and np.array_equal(self.available_actions, other.available_actions)
+
+    def as_tensors(self, device=None):
+        """
+        Convert the observation to a tuple of tensors of shape (1, n_agents, <dim>).
+        """
+        import torch
+
+        data = torch.from_numpy(self.data).unsqueeze(0).to(device, non_blocking=True)
+        extras = torch.from_numpy(self.extras).unsqueeze(0).to(device, non_blocking=True)
+        return data, extras
marlenv/models/state.py CHANGED
@@ -52,3 +52,11 @@ class State(Generic[StateType]):
         if not np.array_equal(self.extras, value.extras):
             return False
         return True
+
+    def as_tensors(self, device=None):
+        """Convert the state to a tuple of tensors of shape (1, <dim>)."""
+        import torch
+
+        data = torch.from_numpy(self.data).unsqueeze(0).to(device, non_blocking=True)
+        extras = torch.from_numpy(self.extras).unsqueeze(0).to(device, non_blocking=True)
+        return data, extras
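Both new `as_tensors` helpers perform the same numpy-to-torch conversion: `torch.from_numpy` wraps the array without copying, `unsqueeze(0)` prepends a batch dimension of 1, and `.to(device, non_blocking=True)` moves the result to the requested device. A minimal, runnable illustration of that transform on a plain array (the shapes and the "cpu" device below are illustrative assumptions, not values taken from the package):

    import numpy as np
    import torch

    obs_data = np.zeros((2, 5), dtype=np.float32)      # e.g. (n_agents, obs_dim); illustrative shape
    batched = torch.from_numpy(obs_data).unsqueeze(0)   # shape (1, 2, 5); shares memory with obs_data
    batched = batched.to("cpu", non_blocking=True)      # move to the requested device
    print(batched.shape)                                 # torch.Size([1, 2, 5])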
marlenv/utils/__init__.py CHANGED
@@ -1,5 +1,5 @@
-from .schedule import Schedule, MultiSchedule, RoundedSchedule, LinearSchedule, ExpSchedule
-
+from .cached_property_collector import CachedPropertyCollector, CachedPropertyInvalidator
+from .schedule import ExpSchedule, LinearSchedule, MultiSchedule, RoundedSchedule, Schedule
 
 __all__ = [
     "Schedule",
@@ -7,4 +7,6 @@ __all__ = [
     "ExpSchedule",
     "MultiSchedule",
     "RoundedSchedule",
+    "CachedPropertyCollector",
+    "CachedPropertyInvalidator",
 ]
marlenv/utils/cached_property_collector.py ADDED
@@ -0,0 +1,17 @@
+from functools import cached_property
+
+
+class CachedPropertyCollector(type):
+    def __init__(cls, name: str, bases: tuple, namespace: dict):
+        super().__init__(name, bases, namespace)
+        cls.CACHED_PROPERTY_NAMES = [key for key, value in namespace.items() if isinstance(value, cached_property)]
+
+
+class CachedPropertyInvalidator(metaclass=CachedPropertyCollector):
+    def __init__(self):
+        super().__init__()
+
+    def invalidate_cached_properties(self):
+        for key in self.__class__.CACHED_PROPERTY_NAMES:
+            if hasattr(self, key):
+                delattr(self, key)
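At class-creation time the `CachedPropertyCollector` metaclass records which attributes of the class body are `functools.cached_property` descriptors; `invalidate_cached_properties` then deletes any cached values from the instance so they are recomputed on the next access. A hedged sketch of the pattern with a hypothetical `RunningStats` class (not part of marlenv):

    from functools import cached_property

    from marlenv.utils import CachedPropertyInvalidator


    class RunningStats(CachedPropertyInvalidator):
        def __init__(self, values: list[float]):
            super().__init__()
            self.values = values

        @cached_property
        def total(self) -> float:
            # Computed once and cached on the instance until invalidated.
            return sum(self.values)


    stats = RunningStats([1.0, 2.0])
    assert stats.total == 3.0
    stats.values.append(3.0)
    stats.invalidate_cached_properties()  # drops the cached 'total'
    assert stats.total == 6.0

This ties in with the episode.py changes above, where several `@cached_property` accessors became plain `@property` and `Episode` now inherits from `CachedPropertyInvalidator`, presumably so that values which can change while an episode is still being filled are never served stale.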
marlenv/wrappers/__init__.py CHANGED
@@ -11,6 +11,7 @@ from .centralised import Centralized
 from .available_actions_mask import AvailableActionsMask
 from .delayed_rewards import DelayedReward
 from .potential_shaping import PotentialShaping
+from .action_randomizer import ActionRandomizer
 
 __all__ = [
     "RLEnvWrapper",
@@ -28,4 +29,5 @@ __all__ = [
     "Centralized",
     "DelayedReward",
     "PotentialShaping",
+    "ActionRandomizer",
 ]
marlenv/wrappers/action_randomizer.py ADDED
@@ -0,0 +1,17 @@
+from .rlenv_wrapper import RLEnvWrapper, AS, MARLEnv
+import numpy as np
+
+
+class ActionRandomizer(RLEnvWrapper[AS]):
+    def __init__(self, env: MARLEnv[AS], p: float):
+        super().__init__(env)
+        self.p = p
+
+    def step(self, action):
+        if np.random.rand() < self.p:
+            action = self.action_space.sample()
+        return super().step(action)
+
+    def seed(self, seed_value: int):
+        np.random.seed(seed_value)
+        super().seed(seed_value)
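`ActionRandomizer` is a new wrapper that, with probability `p` at each step, replaces the provided action with a random sample from the wrapped environment's action space. A hedged usage sketch; `build_env()` is a hypothetical placeholder for however the underlying `MARLEnv` is constructed:

    from marlenv.wrappers import ActionRandomizer

    env = ActionRandomizer(build_env(), p=0.1)  # roughly 10% of steps get a random action
    env.seed(42)

Note that `seed` reseeds numpy's global RNG (via `np.random.seed`), so seeding this wrapper also affects any other code that relies on the global numpy random state.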
multi_agent_rlenv-3.5.2.dist-info/METADATA → multi_agent_rlenv-3.5.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: multi-agent-rlenv
-Version: 3.5.2
+Version: 3.5.5
 Summary: A strongly typed Multi-Agent Reinforcement Learning framework
 Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
 Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
@@ -19,6 +19,7 @@ Requires-Dist: pymunk>=6.0; extra == 'all'
 Requires-Dist: pysc2; extra == 'all'
 Requires-Dist: scipy>=1.10; extra == 'all'
 Requires-Dist: smac; extra == 'all'
+Requires-Dist: torch>=2.0; extra == 'all'
 Provides-Extra: gym
 Requires-Dist: gymnasium>=0.29.1; extra == 'gym'
 Provides-Extra: overcooked
@@ -31,6 +32,8 @@ Requires-Dist: scipy>=1.10; extra == 'pettingzoo'
 Provides-Extra: smac
 Requires-Dist: pysc2; extra == 'smac'
 Requires-Dist: smac; extra == 'smac'
+Provides-Extra: torch
+Requires-Dist: torch>=2.0; extra == 'torch'
 Description-Content-Type: text/markdown
 
 # `marlenv` - A unified framework for muti-agent reinforcement learning
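The metadata now declares torch>=2.0 as an optional dependency, both through the existing `all` extra and a new dedicated `torch` extra, so the tensor helpers added above should be installable with standard pip extras syntax, e.g. `pip install "multi-agent-rlenv[torch]"` (the extra name comes from the metadata above; the command itself is ordinary pip usage, not something stated in the diff).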
multi_agent_rlenv-3.5.2.dist-info/RECORD → multi_agent_rlenv-3.5.5.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
-marlenv/__init__.py,sha256=UoZATsYMuKlnHyYdIRX7eQ6mGcmMww-tqX3uCyWVqRA,3656
+marlenv/__init__.py,sha256=bX76JknjwfVJ6IOKql_y4rIqvvx9raepD7u2lB9CgGo,3656
 marlenv/env_builder.py,sha256=RJoHJLYAUE1ausAoJiRC3fUxyxpH1WRJf7Sdm2ml-uk,5517
 marlenv/env_pool.py,sha256=nCEBkGQU62fcvCAANyAqY8gCFjYlVnSCg-V3Fhx00yc,933
 marlenv/exceptions.py,sha256=gJUC_2rVAvOfK_ypVFc7Myh-pIfSU3To38VBVS_0rZA,1179
@@ -12,15 +12,17 @@ marlenv/adapters/pymarl_adapter.py,sha256=2s7EY31s1hrml3q-BBaXo_eDMXTjkebozZPvzs
 marlenv/adapters/smac_adapter.py,sha256=8uWC7YKsaSXeTS8AUhpGOKvrWMbVEQT2-pml5BaFUB0,8343
 marlenv/models/__init__.py,sha256=uihmRs71Gg5z7Bvau_xtaQVg7xEtX8sTzi74bIHL5P0,443
 marlenv/models/env.py,sha256=BG1iVHxGD_p827mF0ewyOBn6wU2gtFsHLW1b4UtW-V0,7841
-marlenv/models/episode.py,sha256=IKPLuDVlz85Be6zrC21gyautjqRkEApS4fgRqQR52s0,15190
-marlenv/models/observation.py,sha256=kAmh1hIoC2TGrZlGVzV0y4TXXCSrI7gcmG0raeoncYk,3153
+marlenv/models/episode.py,sha256=zsyxsW4LIioPKyY4DZKn64A31e5ZvlwOf3HIGuRUzhs,13531
+marlenv/models/observation.py,sha256=RhvKvmys4bu3UwwVsvu7fJ7TMKt2QkKnBD1e0hw2r7s,3528
 marlenv/models/spaces.py,sha256=v7jnhPfj7vq7DFFJpSbQEYe4NGLLlj_bj2pzvvSBX7Y,7777
-marlenv/models/state.py,sha256=958PXTHadi3gtRnhGgcGtqBnF44R11kdcx62NN2gwxA,1717
+marlenv/models/state.py,sha256=LbP--JxBzRwMFpEAaZyxCX13xKQ27xPE2fabohaq9YI,2058
 marlenv/models/step.py,sha256=00PhD_ccdCIYAY1SVJdJU91weU0Y_tNIJwK16TN_53I,3056
 marlenv/models/transition.py,sha256=UkJVRNxZoyRkjE7YmKtUf_4xA7cOEh20O60dTldbvys,5070
-marlenv/utils/__init__.py,sha256=C3qhvkVwctBP8mG3G5nkAZ5DKfErVRkdbHo7oeWVsM0,209
+marlenv/utils/__init__.py,sha256=36pNw0r4V3xsqPZ5ljM29o96dfPAFq8WvMwggyv41fI,362
+marlenv/utils/cached_property_collector.py,sha256=IOjbr61f0DqLhcidXKrl7MhN1BOEGiTzCANIKQCxaF0,600
 marlenv/utils/schedule.py,sha256=slhtpQiBHSUNyPmSkKb2yBgiHJqPhoPxa33GxvyV8Jc,8565
-marlenv/wrappers/__init__.py,sha256=uV00m0jysZBgOW-TvRekis-gsAKPeR51P3HsuRZKxG8,880
+marlenv/wrappers/__init__.py,sha256=Z4_M-mxRNKQeu52tkmQ4B2m3-zrsmjfXXL5NsWQ4vu4,952
+marlenv/wrappers/action_randomizer.py,sha256=A1kejqGOTA0sc_RQL0EOd6sMSbcIdiV5zlscjKUlzdY,474
 marlenv/wrappers/agent_id_wrapper.py,sha256=9qHV3LMQ4AjcDCSuvQhz5h9hUf7Xtrdi2sIxmNZk5NA,1126
 marlenv/wrappers/available_actions_mask.py,sha256=OMyt2KntsR8JA2RuRgvwdzqzPe-_H-KKkbUUJfe_mks,1404
 marlenv/wrappers/available_actions_wrapper.py,sha256=_HRl9zsjJgSrLgVuT-BjpnnfrfM8ic6wBUWlg67uCx4,926
@@ -34,7 +36,7 @@ marlenv/wrappers/potential_shaping.py,sha256=T_QvnmWReCgpyoInxRw2UXbmdvcBD5U-vV1
 marlenv/wrappers/rlenv_wrapper.py,sha256=S6G1VjFklTEzU6bj0AXrTDXnsTQJARq8VB4uUH6AXe4,2993
 marlenv/wrappers/time_limit.py,sha256=GxbxcbfFyuVg14ylQU2C_cjmV9q4uDAt5wepfgX_PyM,3976
 marlenv/wrappers/video_recorder.py,sha256=ucBQSNRPqDr-2mYxrTCqlrWcxSWtSJ7XlRC9-LdukBM,2535
-multi_agent_rlenv-3.5.2.dist-info/METADATA,sha256=QjQkN0ZJsbaa-GyP7fAs4JFSTJkEUBLrIV0zCGPUvrc,4897
-multi_agent_rlenv-3.5.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-multi_agent_rlenv-3.5.2.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
-multi_agent_rlenv-3.5.2.dist-info/RECORD,,
+multi_agent_rlenv-3.5.5.dist-info/METADATA,sha256=WKf56Bb7PqZFrw2B6Sx8zulM-h7aZRqXMTvYHrSxEtQ,5005
+multi_agent_rlenv-3.5.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+multi_agent_rlenv-3.5.5.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
+multi_agent_rlenv-3.5.5.dist-info/RECORD,,