mapox 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mapox-0.1.0/.gitignore ADDED
@@ -0,0 +1,10 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
@@ -0,0 +1 @@
1
+ 3.11
mapox-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,17 @@
1
+ Metadata-Version: 2.4
2
+ Name: mapox
3
+ Version: 0.1.0
4
+ Summary: Multi-Agent Partially Observable gridworlds in JAX
5
+ Author-email: Gabriel Keith <gfk1995@gmail.com>
6
+ Keywords: gridworld,jax,multi-agent,pomdp,reinforcement-learning
7
+ Classifier: Development Status :: 3 - Alpha
8
+ Classifier: Intended Audience :: Science/Research
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
12
+ Requires-Python: >=3.11
13
+ Requires-Dist: einops>=0.8.1
14
+ Requires-Dist: jax>=0.7.2
15
+ Requires-Dist: pydantic>=2.12.5
16
+ Requires-Dist: pygame-ce>=2.5.6
17
+ Requires-Dist: python-ffmpeg>=2.0.12
mapox-0.1.0/README.md ADDED
File without changes
@@ -0,0 +1,30 @@
1
+ """MAPOX: Multi-Agent Partially Observable gridworlds in JAX"""
2
+
3
+ __version__ = "0.1.0"
4
+
5
+ from mapox.timestep import TimeStep
6
+ from mapox.environment import Environment
7
+ from mapox.specs import ActionSpec, ObservationSpec
8
+ from mapox.wrappers.multitask import MultiTaskWrapper
9
+ from mapox.wrappers.vector import VectorWrapper
10
+ from mapox.client import GridworldClient
11
+ from mapox.utils.encode_one_hot import concat_one_hot
12
+ from mapox.config import create_env, create_client, ScoutsConfig, TravelingSalesmanConfig, KingHillConfig, MultiTaskConfig, EnvironmentConfig
13
+
14
+ __all__ = [
15
+ "TimeStep",
16
+ "ActionSpec",
17
+ "ObservationSpec",
18
+ "Environment",
19
+ "MultiTaskWrapper",
20
+ "VectorWrapper",
21
+ "GridworldClient",
22
+ "concat_one_hot",
23
+ "create_env",
24
+ "create_client",
25
+ "ScoutsConfig",
26
+ "TravelingSalesmanConfig",
27
+ "KingHillConfig",
28
+ "MultiTaskConfig",
29
+ "EnvironmentConfig"
30
+ ]
@@ -0,0 +1,53 @@
1
+ import pygame
2
+ from abc import ABC, abstractmethod
3
+
4
+ from mapox.renderer import GridworldRenderer, GridRenderState
5
+ from mapox.timestep import TimeStep
6
+
7
+
8
+ class EnvironmentClient[State](ABC):
9
+ @abstractmethod
10
+ def render(self, state: State, timestep: TimeStep): ...
11
+
12
+ @abstractmethod
13
+ def save_video(self): ...
14
+
15
+
16
class GridworldClient:
    """Render client backed by GridworldRenderer.

    The wrapped env must expose two adapter hooks: get_render_state(state),
    which converts an env state into a GridRenderState, and
    get_render_settings(), which configures the renderer once at startup.
    """

    def __init__(
        self, env, screen_width: int = 960, screen_height: int = 960, fps: int = 10
    ):
        # Fail fast if the env lacks the adapter hooks the renderer relies on.
        assert hasattr(env, "get_render_state"), (
            "Env must implement get_render_state(state)"
        )
        assert hasattr(env, "get_render_settings"), (
            "Env must implement get_render_settings()"
        )
        self.env = env
        self.renderer = GridworldRenderer(
            screen_width=screen_width, screen_height=screen_height, fps=fps
        )
        self.renderer.set_env(env.get_render_settings())

    def render(self, state, timestep):
        """Draw the full grid for the given state (timestep is unused here)."""
        self.renderer.render(self.env.get_render_state(state))

    def render_pov(self, state, timestep):
        """Render only the focused agent's point-of-view, filling the screen."""
        self.renderer.render_agent_view(self.env.get_render_state(state))

    def handle_event(self, event: pygame.event.Event) -> bool:
        """Forward a pygame event to the renderer; returns its handled flag."""
        return self.renderer.handle_event(event)

    def record_frame(self):
        """Capture the current frame for later video export."""
        self.renderer.record_frame()

    def save_video(self, file_name: str):
        """Write all recorded frames to `file_name` as a video."""
        self.renderer.save_video(file_name)

    def focus_agent(self, agent_id: int | None):
        """Focus the renderer on one agent, or clear focus with None."""
        self.renderer.focus_agent(agent_id)
@@ -0,0 +1,92 @@
1
+ from typing import Literal
2
+
3
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
4
+
5
+ from mapox.envs.king_hill import KingHillConfig, KingHillEnv
6
+ from mapox.envs.grid_return import ReturnDiggingConfig, ReturnDiggingEnv
7
+ from mapox.envs.traveling_salesman import (
8
+ TravelingSalesmanConfig,
9
+ TravelingSalesmanEnv,
10
+ )
11
+ from mapox.envs.scouts import ScoutsConfig, ScoutsEnv
12
+ from mapox.client import GridworldClient
13
+
14
+ from mapox.client import EnvironmentClient
15
+ from mapox.environment import Environment
16
+ from mapox.wrappers.task_id_wrapper import TaskIdWrapper
17
+ from mapox.wrappers.multitask import MultiTaskWrapper
18
+ from mapox.wrappers.vector import VectorWrapper
19
+
20
# Union of all single-environment configs. Each member carries a distinct
# `env_type` literal, which create_env dispatches on and which serves as the
# pydantic discriminator (see the Field(discriminator="env_type") usage).
type EnvironmentConfig = (
    ReturnDiggingConfig
    | TravelingSalesmanConfig
    | ScoutsConfig
    | KingHillConfig
)
26
+
27
+
28
class MultiTaskEnvConfig(BaseModel):
    """One named task entry inside a MultiTaskConfig."""

    model_config = ConfigDict(extra="forbid", frozen=True)
    # Number of vectorized copies built for this env (passed as vec_count
    # by create_env when assembling the multi-task wrapper).
    num: int = 1
    # Unique task name; used to select a single task via create_env(env_name=...).
    name: str
    # Underlying environment config, validated via the env_type discriminator.
    env: EnvironmentConfig = Field(discriminator="env_type")
33
+
34
+
35
class MultiTaskConfig(BaseModel):
    """Config for a collection of environments run as one multi-task env."""

    model_config = ConfigDict(extra="forbid", frozen=True)
    # Discriminator value distinguishing this from single-env configs.
    env_type: Literal["multi"] = "multi"
    # Ordered task entries; order defines each task's task_id in create_env.
    envs: tuple[MultiTaskEnvConfig, ...]

    @field_validator("envs", mode="before")
    @classmethod
    def coerce_envs(cls, v):
        # JSON parsing yields a list; coerce it to the tuple the field expects.
        return tuple(v) if isinstance(v, list) else v
45
+
46
+
47
def create_env(
    env_config: EnvironmentConfig | MultiTaskConfig,
    length: int,
    vec_count: int = 1,
    env_name: str | None = None,
) -> tuple[Environment, int]:
    """Build an Environment from its config.

    Returns (env, num_tasks): num_tasks is 1 for single-env configs and the
    number of task entries for multi-task configs. When env_name is supplied
    together with a multi config, only that named task is built, wrapped with
    its task id, and returned (num_tasks still reflects the full config).

    Raises:
        ValueError: on an unknown env_type, or when env_name is not found.
    """
    # Named-task selection: build just one entry of a multi-task config.
    if env_config.env_type == "multi" and env_name is not None:
        task_count = len(env_config.envs)
        for task_id, entry in enumerate(env_config.envs):
            if entry.name != env_name:
                continue
            inner, _ = create_env(entry.env, length, vec_count=vec_count)
            return TaskIdWrapper(inner, task_id), task_count
        raise ValueError("Could not find environment matching env_name")

    num_tasks = 1
    kind = env_config.env_type
    if kind == "multi":
        num_tasks = len(env_config.envs)
        # Each sub-env is vectorized by its own per-entry `num`.
        sub_envs = tuple(
            create_env(entry.env, length, entry.num)[0] for entry in env_config.envs
        )
        names = tuple(entry.name for entry in env_config.envs)
        env = MultiTaskWrapper(sub_envs, names)
    elif kind == "return_digging":
        env = ReturnDiggingEnv(env_config, length)
    elif kind == "scouts":
        env = ScoutsEnv(env_config, length)
    elif kind == "traveling_salesman":
        env = TravelingSalesmanEnv(env_config, length)
    elif kind == "king_hill":
        env = KingHillEnv(env_config, length)
    else:
        raise ValueError(f"Unknown environment type: {env_config.env_type}")

    # Outer vectorization applies to whichever env was built above.
    if vec_count > 1:
        env = VectorWrapper(env, vec_count)

    return env, num_tasks
88
+
89
+
90
+ def create_client[State](env: Environment[State]) -> EnvironmentClient[State]:
91
+ # use the grid client for all environments for now
92
+ return GridworldClient(env)
@@ -0,0 +1,52 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import cached_property
3
+ from typing import Any
4
+
5
+ import jax
6
+ from jax import Array
7
+
8
+ from mapox.timestep import TimeStep
9
+ from mapox.specs import ObservationSpec, ActionSpec
10
+ import enum
11
+
12
+
13
class StepType(enum.IntEnum):
    """Where a timestep falls within an episode."""

    FIRST = 0  # first step of an episode
    MID = 1  # any intermediate step
    LAST = 2  # final step of an episode
17
+
18
+
19
class Environment[State](ABC):
    """Abstract multi-agent environment with a functional state-in/state-out API."""

    @abstractmethod
    def reset(self, rng_key: Array) -> tuple[State, TimeStep]:
        """Start a new episode; returns the initial state and first timestep."""
        ...

    @abstractmethod
    def step(
        self, state: State, action: Array, rng_key: Array
    ) -> tuple[State, TimeStep]:
        """Advance one step from `state` under `action`; returns (new_state, timestep)."""
        ...

    @abstractmethod
    def create_placeholder_logs(self) -> dict[str, Any]:
        # NOTE(review): presumably mirrors the key structure of create_logs so
        # logs can be pre-allocated before any state exists — confirm in impls.
        ...

    @abstractmethod
    def create_logs(self, state: State) -> dict[str, Any]:
        """Return metrics/log values derived from `state`."""
        ...

    @cached_property
    @abstractmethod
    def observation_spec(self) -> ObservationSpec:
        """Spec describing observation shape/dtype; computed once and cached."""
        ...

    @cached_property
    @abstractmethod
    def action_spec(self) -> ActionSpec:
        """Spec describing the action space; computed once and cached."""
        ...

    @property
    @abstractmethod
    def num_agents(self) -> int:
        """Number of agents acting in this environment."""
        ...

    @property
    @abstractmethod
    def is_jittable(self) -> bool:
        # NOTE(review): looks like this flags whether step/reset are safe to
        # jax.jit — confirm against wrapper usage.
        ...

    @property
    def teams(self) -> jax.Array | None:
        # Per-agent team assignment; default None means no team structure.
        return None
File without changes
@@ -0,0 +1,75 @@
1
+ from jax import numpy as jnp
2
+
3
+ from mapox.specs import ObservationSpec
4
+
5
# Total number of distinct tile/agent type ids (0..14 below).
NUM_TYPES = 15

# Unified tile ids across gridworld environments
TILE_EMPTY = 0  # empty space
TILE_WALL = 1  # permanent wall
TILE_DESTRUCTIBLE_WALL = 2  # destructible
TILE_FLAG = 3  # typical goal tile
TILE_FLAG_UNLOCKED = 4  # used for the scouting environment where the flag gets made available for taking

# NOTE(review): these commented-out ids collide with the agent ids below
# (5 and 6 are now AGENT_GENERIC / AGENT_SCOUT) — remove or renumber.
# TILE_FLAG_BLUE_TEAM = 5
# TILE_FLAG_RED_TEAM = 6

# agents are observed like tiles (they share the same id space)
AGENT_GENERIC = 5  # typical agent
AGENT_SCOUT = 6  # scout agent (scout env)
AGENT_HARVESTER = 7  # harvester agent (scout env)

AGENT_KNIGHT = 8
AGENT_ARCHER = 9

# Decorative tile variants.
TILE_DECOR_1 = 10
TILE_DECOR_2 = 11
TILE_DECOR_3 = 12
TILE_DECOR_4 = 13

TILE_ARROW = 14  # presumably a projectile tile (archer) — confirm in envs


# Actions
NUM_ACTIONS = 7

MOVE_UP = 0
MOVE_RIGHT = 1
MOVE_DOWN = 2
MOVE_LEFT = 3
STAY = 4
PRIMARY_ACTION = 5
DIG_ACTION = 6


# Row i is the grid offset for move action i (MOVE_UP..MOVE_LEFT).
# NOTE(review): axis convention (row/col vs x/y) is not visible here —
# confirm against the env step implementations before relying on it.
DIRECTIONS = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]], dtype=jnp.int32)
46
+
47
+
48
def make_obs_spec(width: int, height: int) -> ObservationSpec:
    """Build the shared (width, height, 4) int8 observation spec.

    Channel layout with per-channel max values:
      0: tile/agent type id  (NUM_TYPES)
      1: facing direction    (5: none, up, right, down, left)
      2: team id             (3: none, red, blue)
      3: health              (3: 0, 1, 2)
    """
    return ObservationSpec(
        dtype=jnp.int8,
        shape=(width, height, 4),
        max_value=(NUM_TYPES, 5, 3, 3),
    )
66
+
67
+
68
def make_action_mask(actions: list[int], num_agents: int):
    """Build a (num_agents, NUM_ACTIONS) boolean mask.

    Entry [a, i] is True iff action id i appears in `actions`; the same row
    is replicated for every agent.
    """
    allowed = set(actions)
    row = jnp.array([a in allowed for a in range(NUM_ACTIONS)], jnp.bool)
    return jnp.tile(row, (num_agents, 1))