mapox 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mapox-0.1.0/.gitignore +10 -0
- mapox-0.1.0/.python-version +1 -0
- mapox-0.1.0/PKG-INFO +17 -0
- mapox-0.1.0/README.md +0 -0
- mapox-0.1.0/mapox/__init__.py +30 -0
- mapox-0.1.0/mapox/client.py +53 -0
- mapox-0.1.0/mapox/config.py +92 -0
- mapox-0.1.0/mapox/environment.py +52 -0
- mapox-0.1.0/mapox/envs/__init__.py +0 -0
- mapox-0.1.0/mapox/envs/constance.py +75 -0
- mapox-0.1.0/mapox/envs/grid_return.py +389 -0
- mapox-0.1.0/mapox/envs/king_hill.py +536 -0
- mapox-0.1.0/mapox/envs/scouts.py +357 -0
- mapox-0.1.0/mapox/envs/traveling_salesman.py +296 -0
- mapox-0.1.0/mapox/map_generator.py +137 -0
- mapox-0.1.0/mapox/renderer.py +318 -0
- mapox-0.1.0/mapox/specs.py +19 -0
- mapox-0.1.0/mapox/timestep.py +21 -0
- mapox-0.1.0/mapox/utils/__init__.py +0 -0
- mapox-0.1.0/mapox/utils/encode_one_hot.py +21 -0
- mapox-0.1.0/mapox/utils/video_writter.py +50 -0
- mapox-0.1.0/mapox/wrappers/__init__.py +0 -0
- mapox-0.1.0/mapox/wrappers/multitask.py +94 -0
- mapox-0.1.0/mapox/wrappers/task_id_wrapper.py +23 -0
- mapox-0.1.0/mapox/wrappers/vector.py +68 -0
- mapox-0.1.0/pyproject.toml +32 -0
- mapox-0.1.0/uv.lock +497 -0
mapox-0.1.0/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.11
|
mapox-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mapox
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Multi-Agent Partially Observable gridworlds in JAX
|
|
5
|
+
Author-email: Gabriel Keith <gfk1995@gmail.com>
|
|
6
|
+
Keywords: gridworld,jax,multi-agent,pomdp,reinforcement-learning
|
|
7
|
+
Classifier: Development Status :: 3 - Alpha
|
|
8
|
+
Classifier: Intended Audience :: Science/Research
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
12
|
+
Requires-Python: >=3.11
|
|
13
|
+
Requires-Dist: einops>=0.8.1
|
|
14
|
+
Requires-Dist: jax>=0.7.2
|
|
15
|
+
Requires-Dist: pydantic>=2.12.5
|
|
16
|
+
Requires-Dist: pygame-ce>=2.5.6
|
|
17
|
+
Requires-Dist: python-ffmpeg>=2.0.12
|
mapox-0.1.0/README.md
ADDED
|
File without changes
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""MAPOX: Multi-Agent Partially Observable gridworlds in JAX"""
|
|
2
|
+
|
|
3
|
+
__version__ = "0.1.0"
|
|
4
|
+
|
|
5
|
+
from mapox.timestep import TimeStep
|
|
6
|
+
from mapox.environment import Environment
|
|
7
|
+
from mapox.specs import ActionSpec, ObservationSpec
|
|
8
|
+
from mapox.wrappers.multitask import MultiTaskWrapper
|
|
9
|
+
from mapox.wrappers.vector import VectorWrapper
|
|
10
|
+
from mapox.client import GridworldClient
|
|
11
|
+
from mapox.utils.encode_one_hot import concat_one_hot
|
|
12
|
+
from mapox.config import create_env, create_client, ScoutsConfig, TravelingSalesmanConfig, KingHillConfig, MultiTaskConfig, EnvironmentConfig
|
|
13
|
+
|
|
14
|
+
# Names re-exported as the package's public API; keep this list in sync
# with the imports above.
__all__ = [
    "TimeStep",
    "ActionSpec",
    "ObservationSpec",
    "Environment",
    "MultiTaskWrapper",
    "VectorWrapper",
    "GridworldClient",
    "concat_one_hot",
    "create_env",
    "create_client",
    "ScoutsConfig",
    "TravelingSalesmanConfig",
    "KingHillConfig",
    "MultiTaskConfig",
    "EnvironmentConfig"
]
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import pygame
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
|
|
4
|
+
from mapox.renderer import GridworldRenderer, GridRenderState
|
|
5
|
+
from mapox.timestep import TimeStep
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class EnvironmentClient[State](ABC):
    """Abstract interface for clients that visualize environment rollouts.

    ``State`` is the environment-specific state pytree passed to ``render``.
    """

    @abstractmethod
    def render(self, state: State, timestep: TimeStep): ...

    # NOTE(review): concrete implementations (e.g. GridworldClient) take a
    # file_name argument here; Python does not enforce abstract signatures,
    # but the contracts should probably be aligned.
    @abstractmethod
    def save_video(self): ...
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GridworldClient(EnvironmentClient):
    """EnvironmentClient that renders via GridworldRenderer using per-env adapters.

    Fix: ``create_client`` (config.py) is annotated to return
    ``EnvironmentClient[State]``, but this class previously did not subclass
    ``EnvironmentClient``; it now inherits the interface it implements.

    The wrapped env must expose two adapter hooks:
    ``get_render_state(state) -> GridRenderState`` and
    ``get_render_settings()`` (renderer configuration).
    """

    def __init__(
        self, env, screen_width: int = 960, screen_height: int = 960, fps: int = 10
    ):
        # Fail fast if the env does not provide the renderer adapter hooks.
        assert hasattr(env, "get_render_state"), (
            "Env must implement get_render_state(state)"
        )
        assert hasattr(env, "get_render_settings"), (
            "Env must implement get_render_settings()"
        )
        self.env = env
        self.renderer = GridworldRenderer(
            screen_width=screen_width, screen_height=screen_height, fps=fps
        )
        self.renderer.set_env(env.get_render_settings())

    def render(self, state, timestep):
        """Render the full grid for ``state``; ``timestep`` is currently unused."""
        rs: GridRenderState = self.env.get_render_state(state)
        self.renderer.render(rs)

    def render_pov(self, state, timestep):
        """Render only the focused agent's point-of-view, filling the screen."""
        rs: GridRenderState = self.env.get_render_state(state)
        self.renderer.render_agent_view(rs)

    def handle_event(self, event: pygame.event.Event) -> bool:
        """Forward a pygame event to the renderer; returns the renderer's result."""
        return self.renderer.handle_event(event)

    def record_frame(self):
        """Capture the current frame for later video export."""
        self.renderer.record_frame()

    def save_video(self, file_name: str):
        """Write all recorded frames to ``file_name``."""
        self.renderer.save_video(file_name)

    def focus_agent(self, agent_id: int | None):
        """Focus rendering on ``agent_id`` (None clears the focus)."""
        self.renderer.focus_agent(agent_id)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
4
|
+
|
|
5
|
+
from mapox.envs.king_hill import KingHillConfig, KingHillEnv
|
|
6
|
+
from mapox.envs.grid_return import ReturnDiggingConfig, ReturnDiggingEnv
|
|
7
|
+
from mapox.envs.traveling_salesman import (
|
|
8
|
+
TravelingSalesmanConfig,
|
|
9
|
+
TravelingSalesmanEnv,
|
|
10
|
+
)
|
|
11
|
+
from mapox.envs.scouts import ScoutsConfig, ScoutsEnv
|
|
12
|
+
from mapox.client import GridworldClient
|
|
13
|
+
|
|
14
|
+
from mapox.client import EnvironmentClient
|
|
15
|
+
from mapox.environment import Environment
|
|
16
|
+
from mapox.wrappers.task_id_wrapper import TaskIdWrapper
|
|
17
|
+
from mapox.wrappers.multitask import MultiTaskWrapper
|
|
18
|
+
from mapox.wrappers.vector import VectorWrapper
|
|
19
|
+
|
|
20
|
+
# Union of all single-environment config models. Each member model carries a
# Literal ``env_type`` field, which MultiTaskEnvConfig uses as a pydantic
# discriminator and create_env uses for dispatch.
type EnvironmentConfig = (
    ReturnDiggingConfig
    | TravelingSalesmanConfig
    | ScoutsConfig
    | KingHillConfig
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class MultiTaskEnvConfig(BaseModel):
    """One named task entry inside a multi-task configuration."""

    model_config = ConfigDict(extra="forbid", frozen=True)
    # how many copies of this task's env to build (used as vec_count in create_env)
    num: int = 1
    # task name; create_env matches env_name against this to select a task
    name: str
    # concrete environment config, discriminated by its env_type field
    env: EnvironmentConfig = Field(discriminator="env_type")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MultiTaskConfig(BaseModel):
    """Top-level config describing a collection of named environment tasks."""

    model_config = ConfigDict(extra="forbid", frozen=True)
    env_type: Literal["multi"] = "multi"
    envs: tuple[MultiTaskEnvConfig, ...]

    @field_validator("envs", mode="before")
    @classmethod
    def coerce_envs(cls, raw):
        """Accept a JSON-style list for ``envs`` and coerce it to a tuple."""
        if isinstance(raw, list):
            return tuple(raw)
        return raw
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def create_env(
    env_config: EnvironmentConfig | MultiTaskConfig,
    length: int,
    vec_count: int = 1,
    env_name: str | None = None,
) -> tuple[Environment, int]:
    """Instantiate an environment from its config.

    Returns ``(env, num_tasks)``. For a multi-task config, ``num_tasks`` is
    the number of task entries; otherwise it is 1. If ``env_name`` is given
    together with a multi-task config, only the matching task is built and
    wrapped in a TaskIdWrapper carrying that task's index.

    Raises:
        ValueError: if ``env_name`` matches no task, or the env type is unknown.
    """
    num_tasks = 1

    # Single-task selection out of a multi-task config.
    if env_config.env_type == "multi" and env_name is not None:
        num_tasks = len(env_config.envs)
        for task_id, task_def in enumerate(env_config.envs):
            if task_def.name != env_name:
                continue
            inner, _ = create_env(task_def.env, length, vec_count=vec_count)
            return TaskIdWrapper(inner, task_id), num_tasks
        raise ValueError("Could not find environment matching env_name")

    kind = env_config.env_type
    if kind == "multi":
        # Build every member task (vectorized per its own `num`) and bundle them.
        num_tasks = len(env_config.envs)
        members = []
        names = []
        for task_def in env_config.envs:
            members.append(create_env(task_def.env, length, task_def.num)[0])
            names.append(task_def.name)
        env = MultiTaskWrapper(tuple(members), tuple(names))
    elif kind == "return_digging":
        env = ReturnDiggingEnv(env_config, length)
    elif kind == "scouts":
        env = ScoutsEnv(env_config, length)
    elif kind == "traveling_salesman":
        env = TravelingSalesmanEnv(env_config, length)
    elif kind == "king_hill":
        env = KingHillEnv(env_config, length)
    else:
        raise ValueError(f"Unknown environment type: {env_config.env_type}")

    # Vectorize the result on request (applies to multi-task bundles as well).
    if vec_count > 1:
        env = VectorWrapper(env, vec_count)

    return env, num_tasks
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def create_client[State](env: Environment[State]) -> EnvironmentClient[State]:
    """Create a rendering client for ``env``."""
    # use the grid client for all environments for now
    return GridworldClient(env)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from functools import cached_property
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
from jax import Array
|
|
7
|
+
|
|
8
|
+
from mapox.timestep import TimeStep
|
|
9
|
+
from mapox.specs import ObservationSpec, ActionSpec
|
|
10
|
+
import enum
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class StepType(enum.IntEnum):
    """Position of a timestep within an episode."""

    FIRST = 0  # first step (after a reset)
    MID = 1  # intermediate step
    LAST = 2  # last step (episode boundary)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Environment[State](ABC):
|
|
20
|
+
@abstractmethod
|
|
21
|
+
def reset(self, rng_key: Array) -> tuple[State, TimeStep]: ...
|
|
22
|
+
|
|
23
|
+
@abstractmethod
|
|
24
|
+
def step(
|
|
25
|
+
self, state: State, action: Array, rng_key: Array
|
|
26
|
+
) -> tuple[State, TimeStep]: ...
|
|
27
|
+
|
|
28
|
+
@abstractmethod
|
|
29
|
+
def create_placeholder_logs(self) -> dict[str, Any]: ...
|
|
30
|
+
|
|
31
|
+
@abstractmethod
|
|
32
|
+
def create_logs(self, state) -> dict[str, Any]: ...
|
|
33
|
+
|
|
34
|
+
@cached_property
|
|
35
|
+
@abstractmethod
|
|
36
|
+
def observation_spec(self) -> ObservationSpec: ...
|
|
37
|
+
|
|
38
|
+
@cached_property
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def action_spec(self) -> ActionSpec: ...
|
|
41
|
+
|
|
42
|
+
@property
|
|
43
|
+
@abstractmethod
|
|
44
|
+
def num_agents(self) -> int: ...
|
|
45
|
+
|
|
46
|
+
@property
|
|
47
|
+
@abstractmethod
|
|
48
|
+
def is_jittable(self) -> bool: ...
|
|
49
|
+
|
|
50
|
+
@property
|
|
51
|
+
def teams(self) -> jax.Array | None:
|
|
52
|
+
return None
|
|
File without changes
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from jax import numpy as jnp
|
|
2
|
+
|
|
3
|
+
from mapox.specs import ObservationSpec
|
|
4
|
+
|
|
5
|
+
# Total number of distinct tile/agent type ids (ids 0..14 below).
NUM_TYPES = 15

# Unified tile ids across gridworld environments
TILE_EMPTY = 0  # empty space
TILE_WALL = 1  # permanent wall
TILE_DESTRUCTIBLE_WALL = 2  # destructible
TILE_FLAG = 3  # typical goal tile
TILE_FLAG_UNLOCKED = 4  # used for the scouting environment where the flag gets made available for taking

# TILE_FLAG_BLUE_TEAM = 5
# TILE_FLAG_RED_TEAM = 6

# agents are observed like tiles (their ids share the same value space)
AGENT_GENERIC = 5  # typical agent
AGENT_SCOUT = 6  # scout agent (scout env)
AGENT_HARVESTER = 7  # harvester agent (scout env)

# presumably the king-of-the-hill env agent classes — confirm against king_hill.py
AGENT_KNIGHT = 8
AGENT_ARCHER = 9

# purely decorative tile variants
TILE_DECOR_1 = 10
TILE_DECOR_2 = 11
TILE_DECOR_3 = 12
TILE_DECOR_4 = 13

TILE_ARROW = 14


# Actions
NUM_ACTIONS = 7

MOVE_UP = 0
MOVE_RIGHT = 1
MOVE_DOWN = 2
MOVE_LEFT = 3
STAY = 4
PRIMARY_ACTION = 5
DIG_ACTION = 6


# Per-action movement deltas, row i for action i (MOVE_UP..MOVE_LEFT).
# NOTE(review): axis convention ([0, 1] for MOVE_UP) is defined by the env
# grids — confirm whether deltas are (x, y) or (row, col).
DIRECTIONS = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]], dtype=jnp.int32)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def make_obs_spec(width: int, height: int) -> ObservationSpec:
    """Build the grid observation spec shared by the environments.

    The observation is a (width, height, 4) int8 array whose channels are:
    tile type, direction, team id, and health.
    """
    channel_maxes = (
        NUM_TYPES,  # tile/agent type id
        5,  # direction: none, up, right, down, left
        3,  # team id: none, red, blue
        3,  # health: 0, 1, 2
    )
    return ObservationSpec(
        dtype=jnp.int8,
        shape=(width, height, 4),
        max_value=channel_maxes,
    )
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def make_action_mask(actions: list[int], num_agents: int):
    """Build a (num_agents, NUM_ACTIONS) boolean mask.

    Every agent gets the same row: True at each index listed in ``actions``,
    False everywhere else.
    """
    allowed = set(actions)
    row = jnp.array([a in allowed for a in range(NUM_ACTIONS)], jnp.bool)
    # Broadcast the single row to one copy per agent.
    return jnp.repeat(row[None, :], num_agents, axis=0)
|