multi-agent-rlenv 3.5.2__py3-none-any.whl → 3.5.5__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- marlenv/__init__.py +1 -1
- marlenv/models/episode.py +13 -56
- marlenv/models/observation.py +10 -0
- marlenv/models/state.py +8 -0
- marlenv/utils/__init__.py +4 -2
- marlenv/utils/cached_property_collector.py +17 -0
- marlenv/wrappers/__init__.py +2 -0
- marlenv/wrappers/action_randomizer.py +17 -0
- {multi_agent_rlenv-3.5.2.dist-info → multi_agent_rlenv-3.5.5.dist-info}/METADATA +4 -1
- {multi_agent_rlenv-3.5.2.dist-info → multi_agent_rlenv-3.5.5.dist-info}/RECORD +12 -10
- {multi_agent_rlenv-3.5.2.dist-info → multi_agent_rlenv-3.5.5.dist-info}/WHEEL +0 -0
- {multi_agent_rlenv-3.5.2.dist-info → multi_agent_rlenv-3.5.5.dist-info}/licenses/LICENSE +0 -0
marlenv/__init__.py
CHANGED
@@ -62,7 +62,7 @@ print(env.extras_shape) # (1, )
 If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
 """
 
-__version__ = "3.5.2"
+__version__ = "3.5.5"
 
 from . import models
 from .models import (

marlenv/models/episode.py
CHANGED
@@ -2,30 +2,32 @@ from dataclasses import dataclass
 from functools import cached_property
 from typing import Any, Callable, Optional, Sequence, overload
 
+import cv2
 import numpy as np
 import numpy.typing as npt
-import cv2
 
+from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException
+from marlenv.utils import CachedPropertyInvalidator
+
+from .env import MARLEnv
 from .observation import Observation
 from .state import State
 from .step import Step
 from .transition import Transition
-from .env import MARLEnv
-from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException
 
 
 @dataclass
-class Episode:
+class Episode(CachedPropertyInvalidator):
     """Episode model made of observations, actions, rewards, ..."""
 
     all_observations: list[npt.NDArray[np.float32]]
     all_extras: list[npt.NDArray[np.float32]]
     actions: list[npt.NDArray]
     rewards: list[npt.NDArray[np.float32]]
-    all_available_actions: list[npt.NDArray[np.
+    all_available_actions: list[npt.NDArray[np.bool]]
     all_states: list[npt.NDArray[np.float32]]
     all_states_extras: list[npt.NDArray[np.float32]]
-    metrics: dict[str, float]
+    metrics: dict[str, Any]
     episode_len: int
     other: dict[str, list[Any]]
     is_done: bool = False
@@ -33,7 +35,7 @@ class Episode:
     """Whether the episode did reach a terminal state (different from truncated)"""
 
     @staticmethod
-    def new(obs: Observation, state: State, metrics: Optional[dict[str, float]] = None) -> "Episode":
+    def new(obs: Observation, state: State, metrics: Optional[dict[str, Any]] = None) -> "Episode":
         if metrics is None:
             metrics = {}
         return Episode(
@@ -153,12 +155,12 @@ class Episode:
         """Get the next extra features"""
         return self.all_extras[1:]
 
-    @cached_property
+    @property
     def n_agents(self):
         """The number of agents in the episode"""
         return self.all_extras[0].shape[0]
 
-    @cached_property
+    @property
     def n_actions(self):
         """The number of actions"""
         return len(self.all_available_actions[0][0])
@@ -267,7 +269,7 @@ class Episode:
     def __len__(self):
         return self.episode_len
 
-    @cached_property
+    @property
     def score(self) -> list[float]:
         """The episode score (sum of all rewards across all objectives)"""
         score = []
@@ -363,51 +365,6 @@ class Episode:
         for i, s in enumerate(scores):
             self.metrics[f"score-{i}"] = float(s)
 
-
-    #     self,
-    #     new_obs: Observation,
-    #     new_state: State,
-    #     action: A,
-    #     reward: np.ndarray,
-    #     done: bool,
-    #     truncated: bool,
-    #     info: dict[str, Any],
-    #     **kwargs,
-    # ):
-    #     """Add a new transition to the episode"""
-    #     self.episode_len += 1
-    #     self.all_observations.append(new_obs.data)
-    #     self.all_extras.append(new_obs.extras)
-    #     self.all_available_actions.append(new_obs.available_actions)
-    #     self.all_states.append(new_state.data)
-    #     self.all_states_extras.append(new_state.extras)
-    #     match action:
-    #         case np.ndarray() as action:
-    #             self.actions.append(action)
-    #         case other:
-    #             self.actions.append(np.array(other))
-    #     self.rewards.append(reward)
-    #     for key, value in kwargs.items():
-    #         current = self.other.get(key, [])
-    #         current.append(value)
-    #         self.other[key] = current
-
-    #     if done:
-    #         # Only set the truncated flag if the episode is not done (both could happen with a time limit)
-    #         self.is_truncated = truncated
-    #     self.is_done = done
-    #     # Add metrics that can be plotted
-    #     for key, value in info.items():
-    #         if isinstance(value, bool):
-    #             value = int(value)
-    #         self.metrics[key] = value
-    #     self.metrics["episode_len"] = self.episode_len
-
-    #     rewards = np.array(self.rewards)
-    #     scores = np.sum(rewards, axis=0)
-    #     for i, s in enumerate(scores):
-    #         self.metrics[f"score-{i}"] = float(s)
-
-    def add_metrics(self, metrics: dict[str, float]):
+    def add_metrics(self, metrics: dict[str, Any]):
         """Add metrics to the episode"""
         self.metrics.update(metrics)

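A plausible reading of this refactor (the diff itself does not state the motivation): an `Episode` is mutable while it is being recorded, so values cached with `@cached_property` can go stale. `n_agents`, `n_actions` and `score` therefore become plain properties that are recomputed on every access, while the new `CachedPropertyInvalidator` base class (added in `marlenv/utils`, shown further below) lets any remaining cached properties be reset explicitly via `invalidate_cached_properties()`.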
marlenv/models/observation.py
CHANGED
@@ -87,3 +87,13 @@ class Observation:
         if not np.array_equal(self.data, other.data):
             return False
         return np.array_equal(self.extras, other.extras) and np.array_equal(self.available_actions, other.available_actions)
+
+    def as_tensors(self, device=None):
+        """
+        Convert the observation to a tuple of tensors of shape (1, n_agents, <dim>).
+        """
+        import torch
+
+        data = torch.from_numpy(self.data).unsqueeze(0).to(device, non_blocking=True)
+        extras = torch.from_numpy(self.extras).unsqueeze(0).to(device, non_blocking=True)
+        return data, extras

marlenv/models/state.py
CHANGED
@@ -52,3 +52,11 @@ class State(Generic[StateType]):
         if not np.array_equal(self.extras, value.extras):
             return False
         return True
+
+    def as_tensors(self, device=None):
+        """Convert the state to a tuple of tensors of shape (1, <dim>)."""
+        import torch
+
+        data = torch.from_numpy(self.data).unsqueeze(0).to(device, non_blocking=True)
+        extras = torch.from_numpy(self.extras).unsqueeze(0).to(device, non_blocking=True)
+        return data, extras

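Both new `as_tensors` methods import `torch` lazily inside the function body, so PyTorch only becomes a requirement when they are actually called; this matches the new optional `torch` extra declared in the METADATA below. A minimal usage sketch for the `Observation` variant follows; the keyword construction and the field shapes are assumptions inferred from the fields visible in this diff, not a documented API:

```python
import numpy as np

from marlenv.models import Observation

# Hypothetical observation: 2 agents, 4-dim observations, 1 extra
# feature and 3 available actions per agent.
obs = Observation(
    data=np.zeros((2, 4), dtype=np.float32),
    extras=np.zeros((2, 1), dtype=np.float32),
    available_actions=np.ones((2, 3), dtype=bool),
)

# as_tensors prepends a batch dimension of 1 and moves the arrays to the
# requested device (non_blocking=True requests an async copy when possible).
data, extras = obs.as_tensors(device="cpu")
print(data.shape, extras.shape)  # torch.Size([1, 2, 4]) torch.Size([1, 2, 1])
```

`State.as_tensors` behaves the same way, just without the agent dimension.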
marlenv/utils/__init__.py
CHANGED
@@ -1,5 +1,5 @@
-from .
-
+from .cached_property_collector import CachedPropertyCollector, CachedPropertyInvalidator
+from .schedule import ExpSchedule, LinearSchedule, MultiSchedule, RoundedSchedule, Schedule
 
 __all__ = [
     "Schedule",
@@ -7,4 +7,6 @@ __all__ = [
     "ExpSchedule",
     "MultiSchedule",
     "RoundedSchedule",
+    "CachedPropertyCollector",
+    "CachedPropertyInvalidator",
 ]

marlenv/utils/cached_property_collector.py
ADDED
@@ -0,0 +1,17 @@
+from functools import cached_property
+
+
+class CachedPropertyCollector(type):
+    def __init__(cls, name: str, bases: tuple, namespace: dict):
+        super().__init__(name, bases, namespace)
+        cls.CACHED_PROPERTY_NAMES = [key for key, value in namespace.items() if isinstance(value, cached_property)]
+
+
+class CachedPropertyInvalidator(metaclass=CachedPropertyCollector):
+    def __init__(self):
+        super().__init__()
+
+    def invalidate_cached_properties(self):
+        for key in self.__class__.CACHED_PROPERTY_NAMES:
+            if hasattr(self, key):
+                delattr(self, key)

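To make the mechanics concrete, here is a minimal sketch of how the new helper behaves; the `Stats` class and its field are hypothetical:

```python
from functools import cached_property

from marlenv.utils import CachedPropertyInvalidator


class Stats(CachedPropertyInvalidator):
    def __init__(self, values: list[float]):
        super().__init__()
        self.values = values

    @cached_property
    def total(self) -> float:
        return sum(self.values)


stats = Stats([1.0, 2.0])
print(stats.total)  # 3.0, computed once and cached on the instance
stats.values.append(3.0)
print(stats.total)  # still 3.0: the cached value is now stale
stats.invalidate_cached_properties()  # deletes the cached entries
print(stats.total)  # 6.0, recomputed on next access
```

The metaclass records each class's `cached_property` names at class-creation time, and `invalidate_cached_properties` simply `delattr`s any cached value from the instance so that the next access recomputes it.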
marlenv/wrappers/__init__.py
CHANGED
@@ -11,6 +11,7 @@ from .centralised import Centralized
 from .available_actions_mask import AvailableActionsMask
 from .delayed_rewards import DelayedReward
 from .potential_shaping import PotentialShaping
+from .action_randomizer import ActionRandomizer
 
 __all__ = [
     "RLEnvWrapper",
@@ -28,4 +29,5 @@ __all__ = [
     "Centralized",
     "DelayedReward",
     "PotentialShaping",
+    "ActionRandomizer",
 ]

marlenv/wrappers/action_randomizer.py
ADDED
@@ -0,0 +1,17 @@
+from .rlenv_wrapper import RLEnvWrapper, AS, MARLEnv
+import numpy as np
+
+
+class ActionRandomizer(RLEnvWrapper[AS]):
+    def __init__(self, env: MARLEnv[AS], p: float):
+        super().__init__(env)
+        self.p = p
+
+    def step(self, action):
+        if np.random.rand() < self.p:
+            action = self.action_space.sample()
+        return super().step(action)
+
+    def seed(self, seed_value: int):
+        np.random.seed(seed_value)
+        super().seed(seed_value)

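Usage is straightforward; a sketch, where `my_env` and `my_action` stand for any existing `MARLEnv` instance and a valid joint action (both hypothetical here). Note a design point visible in the source above: the wrapper draws from and seeds NumPy's global RNG (`np.random.rand` / `np.random.seed`) rather than a per-instance generator, so seeding one `ActionRandomizer` affects everything else that uses the global RNG:

```python
from marlenv.wrappers import ActionRandomizer

# Replace the chosen joint action with a random sample 10% of the time.
env = ActionRandomizer(my_env, p=0.1)
env.seed(42)  # seeds np.random globally, then seeds the wrapped env
step = env.step(my_action)
```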

{multi_agent_rlenv-3.5.2.dist-info → multi_agent_rlenv-3.5.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: multi-agent-rlenv
-Version: 3.5.2
+Version: 3.5.5
 Summary: A strongly typed Multi-Agent Reinforcement Learning framework
 Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
 Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
@@ -19,6 +19,7 @@ Requires-Dist: pymunk>=6.0; extra == 'all'
 Requires-Dist: pysc2; extra == 'all'
 Requires-Dist: scipy>=1.10; extra == 'all'
 Requires-Dist: smac; extra == 'all'
+Requires-Dist: torch>=2.0; extra == 'all'
 Provides-Extra: gym
 Requires-Dist: gymnasium>=0.29.1; extra == 'gym'
 Provides-Extra: overcooked
@@ -31,6 +32,8 @@ Requires-Dist: scipy>=1.10; extra == 'pettingzoo'
 Provides-Extra: smac
 Requires-Dist: pysc2; extra == 'smac'
 Requires-Dist: smac; extra == 'smac'
+Provides-Extra: torch
+Requires-Dist: torch>=2.0; extra == 'torch'
 Description-Content-Type: text/markdown
 
 # `marlenv` - A unified framework for muti-agent reinforcement learning

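With the new `torch` extra, PyTorch remains optional: the `as_tensors` helpers above only need it at call time, and users who want it pulled in automatically can rely on the standard extras syntax, e.g. `pip install "multi-agent-rlenv[torch]"`.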

{multi_agent_rlenv-3.5.2.dist-info → multi_agent_rlenv-3.5.5.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-marlenv/__init__.py,sha256=
+marlenv/__init__.py,sha256=bX76JknjwfVJ6IOKql_y4rIqvvx9raepD7u2lB9CgGo,3656
 marlenv/env_builder.py,sha256=RJoHJLYAUE1ausAoJiRC3fUxyxpH1WRJf7Sdm2ml-uk,5517
 marlenv/env_pool.py,sha256=nCEBkGQU62fcvCAANyAqY8gCFjYlVnSCg-V3Fhx00yc,933
 marlenv/exceptions.py,sha256=gJUC_2rVAvOfK_ypVFc7Myh-pIfSU3To38VBVS_0rZA,1179
@@ -12,15 +12,17 @@ marlenv/adapters/pymarl_adapter.py,sha256=2s7EY31s1hrml3q-BBaXo_eDMXTjkebozZPvzs
 marlenv/adapters/smac_adapter.py,sha256=8uWC7YKsaSXeTS8AUhpGOKvrWMbVEQT2-pml5BaFUB0,8343
 marlenv/models/__init__.py,sha256=uihmRs71Gg5z7Bvau_xtaQVg7xEtX8sTzi74bIHL5P0,443
 marlenv/models/env.py,sha256=BG1iVHxGD_p827mF0ewyOBn6wU2gtFsHLW1b4UtW-V0,7841
-marlenv/models/episode.py,sha256=
-marlenv/models/observation.py,sha256=
+marlenv/models/episode.py,sha256=zsyxsW4LIioPKyY4DZKn64A31e5ZvlwOf3HIGuRUzhs,13531
+marlenv/models/observation.py,sha256=RhvKvmys4bu3UwwVsvu7fJ7TMKt2QkKnBD1e0hw2r7s,3528
 marlenv/models/spaces.py,sha256=v7jnhPfj7vq7DFFJpSbQEYe4NGLLlj_bj2pzvvSBX7Y,7777
-marlenv/models/state.py,sha256=
+marlenv/models/state.py,sha256=LbP--JxBzRwMFpEAaZyxCX13xKQ27xPE2fabohaq9YI,2058
 marlenv/models/step.py,sha256=00PhD_ccdCIYAY1SVJdJU91weU0Y_tNIJwK16TN_53I,3056
 marlenv/models/transition.py,sha256=UkJVRNxZoyRkjE7YmKtUf_4xA7cOEh20O60dTldbvys,5070
-marlenv/utils/__init__.py,sha256=
+marlenv/utils/__init__.py,sha256=36pNw0r4V3xsqPZ5ljM29o96dfPAFq8WvMwggyv41fI,362
+marlenv/utils/cached_property_collector.py,sha256=IOjbr61f0DqLhcidXKrl7MhN1BOEGiTzCANIKQCxaF0,600
 marlenv/utils/schedule.py,sha256=slhtpQiBHSUNyPmSkKb2yBgiHJqPhoPxa33GxvyV8Jc,8565
-marlenv/wrappers/__init__.py,sha256=
+marlenv/wrappers/__init__.py,sha256=Z4_M-mxRNKQeu52tkmQ4B2m3-zrsmjfXXL5NsWQ4vu4,952
+marlenv/wrappers/action_randomizer.py,sha256=A1kejqGOTA0sc_RQL0EOd6sMSbcIdiV5zlscjKUlzdY,474
 marlenv/wrappers/agent_id_wrapper.py,sha256=9qHV3LMQ4AjcDCSuvQhz5h9hUf7Xtrdi2sIxmNZk5NA,1126
 marlenv/wrappers/available_actions_mask.py,sha256=OMyt2KntsR8JA2RuRgvwdzqzPe-_H-KKkbUUJfe_mks,1404
 marlenv/wrappers/available_actions_wrapper.py,sha256=_HRl9zsjJgSrLgVuT-BjpnnfrfM8ic6wBUWlg67uCx4,926
@@ -34,7 +36,7 @@ marlenv/wrappers/potential_shaping.py,sha256=T_QvnmWReCgpyoInxRw2UXbmdvcBD5U-vV1
 marlenv/wrappers/rlenv_wrapper.py,sha256=S6G1VjFklTEzU6bj0AXrTDXnsTQJARq8VB4uUH6AXe4,2993
 marlenv/wrappers/time_limit.py,sha256=GxbxcbfFyuVg14ylQU2C_cjmV9q4uDAt5wepfgX_PyM,3976
 marlenv/wrappers/video_recorder.py,sha256=ucBQSNRPqDr-2mYxrTCqlrWcxSWtSJ7XlRC9-LdukBM,2535
-multi_agent_rlenv-3.5.2.dist-info/METADATA,sha256=
-multi_agent_rlenv-3.5.2.dist-info/WHEEL,sha256=
-multi_agent_rlenv-3.5.2.dist-info/licenses/LICENSE,sha256=
-multi_agent_rlenv-3.5.2.dist-info/RECORD,,
+multi_agent_rlenv-3.5.5.dist-info/METADATA,sha256=WKf56Bb7PqZFrw2B6Sx8zulM-h7aZRqXMTvYHrSxEtQ,5005
+multi_agent_rlenv-3.5.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+multi_agent_rlenv-3.5.5.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
+multi_agent_rlenv-3.5.5.dist-info/RECORD,,

File without changes
|
|
File without changes
|