mapox 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mapox/__init__.py +30 -0
- mapox/client.py +53 -0
- mapox/config.py +92 -0
- mapox/environment.py +52 -0
- mapox/envs/__init__.py +0 -0
- mapox/envs/constance.py +75 -0
- mapox/envs/grid_return.py +389 -0
- mapox/envs/king_hill.py +536 -0
- mapox/envs/scouts.py +357 -0
- mapox/envs/traveling_salesman.py +296 -0
- mapox/map_generator.py +137 -0
- mapox/renderer.py +318 -0
- mapox/specs.py +19 -0
- mapox/timestep.py +21 -0
- mapox/utils/__init__.py +0 -0
- mapox/utils/encode_one_hot.py +21 -0
- mapox/utils/video_writter.py +50 -0
- mapox/wrappers/__init__.py +0 -0
- mapox/wrappers/multitask.py +94 -0
- mapox/wrappers/task_id_wrapper.py +23 -0
- mapox/wrappers/vector.py +68 -0
- mapox-0.1.0.dist-info/METADATA +17 -0
- mapox-0.1.0.dist-info/RECORD +24 -0
- mapox-0.1.0.dist-info/WHEEL +4 -0
mapox/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""MAPOX: Multi-Agent Partially Observable gridworlds in JAX"""
|
|
2
|
+
|
|
3
|
+
__version__ = "0.1.0"
|
|
4
|
+
|
|
5
|
+
from mapox.timestep import TimeStep
|
|
6
|
+
from mapox.environment import Environment
|
|
7
|
+
from mapox.specs import ActionSpec, ObservationSpec
|
|
8
|
+
from mapox.wrappers.multitask import MultiTaskWrapper
|
|
9
|
+
from mapox.wrappers.vector import VectorWrapper
|
|
10
|
+
from mapox.client import GridworldClient
|
|
11
|
+
from mapox.utils.encode_one_hot import concat_one_hot
|
|
12
|
+
from mapox.config import create_env, create_client, ScoutsConfig, TravelingSalesmanConfig, KingHillConfig, MultiTaskConfig, EnvironmentConfig
|
|
13
|
+
|
|
14
|
+
# Names re-exported as the package's public API; each one is imported above.
__all__ = [
    "TimeStep",
    "ActionSpec",
    "ObservationSpec",
    "Environment",
    "MultiTaskWrapper",
    "VectorWrapper",
    "GridworldClient",
    "concat_one_hot",
    "create_env",
    "create_client",
    "ScoutsConfig",
    "TravelingSalesmanConfig",
    "KingHillConfig",
    "MultiTaskConfig",
    "EnvironmentConfig"
]
|
mapox/client.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import pygame
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
|
|
4
|
+
from mapox.renderer import GridworldRenderer, GridRenderState
|
|
5
|
+
from mapox.timestep import TimeStep
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class EnvironmentClient[State](ABC):
|
|
9
|
+
@abstractmethod
|
|
10
|
+
def render(self, state: State, timestep: TimeStep): ...
|
|
11
|
+
|
|
12
|
+
@abstractmethod
|
|
13
|
+
def save_video(self): ...
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GridworldClient(EnvironmentClient):
    """EnvironmentClient that renders via GridworldRenderer using per-env adapters.

    The wrapped environment acts as the adapter: it must provide
    ``get_render_settings()`` (read once at construction) and
    ``get_render_state(state)`` (called for every rendered frame).

    Subclassing ``EnvironmentClient`` makes the class match its own
    description and the return annotation of ``create_client``.
    """

    def __init__(
        self, env, screen_width: int = 960, screen_height: int = 960, fps: int = 10
    ):
        """Create the renderer and configure it from the environment.

        Args:
            env: Environment exposing ``get_render_state`` and
                ``get_render_settings``.
            screen_width: Forwarded to ``GridworldRenderer``.
            screen_height: Forwarded to ``GridworldRenderer``.
            fps: Forwarded to ``GridworldRenderer``.

        Raises:
            AssertionError: If ``env`` lacks one of the two render hooks.
        """
        assert hasattr(env, "get_render_state"), (
            "Env must implement get_render_state(state)"
        )
        assert hasattr(env, "get_render_settings"), (
            "Env must implement get_render_settings()"
        )
        self.env = env
        self.renderer = GridworldRenderer(
            screen_width=screen_width, screen_height=screen_height, fps=fps
        )
        self.renderer.set_env(env.get_render_settings())

    def render(self, state, timestep):
        """Render the whole grid for ``state``; ``timestep`` is not used here."""
        rs: GridRenderState = self.env.get_render_state(state)
        self.renderer.render(rs)

    def render_pov(self, state, timestep):
        """Render only the focused agent's point-of-view, filling the screen."""
        rs: GridRenderState = self.env.get_render_state(state)
        self.renderer.render_agent_view(rs)

    def handle_event(self, event: pygame.event.Event) -> bool:
        """Forward a pygame event to the renderer and return its result."""
        return self.renderer.handle_event(event)

    def record_frame(self):
        """Ask the renderer to record the current frame."""
        self.renderer.record_frame()

    def save_video(self, file_name: str):
        """Ask the renderer to save its video to ``file_name``."""
        self.renderer.save_video(file_name)

    def focus_agent(self, agent_id: int | None):
        """Forward the agent to focus on (or ``None``) to the renderer."""
        self.renderer.focus_agent(agent_id)
|
mapox/config.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
4
|
+
|
|
5
|
+
from mapox.envs.king_hill import KingHillConfig, KingHillEnv
|
|
6
|
+
from mapox.envs.grid_return import ReturnDiggingConfig, ReturnDiggingEnv
|
|
7
|
+
from mapox.envs.traveling_salesman import (
|
|
8
|
+
TravelingSalesmanConfig,
|
|
9
|
+
TravelingSalesmanEnv,
|
|
10
|
+
)
|
|
11
|
+
from mapox.envs.scouts import ScoutsConfig, ScoutsEnv
|
|
12
|
+
from mapox.client import GridworldClient
|
|
13
|
+
|
|
14
|
+
from mapox.client import EnvironmentClient
|
|
15
|
+
from mapox.environment import Environment
|
|
16
|
+
from mapox.wrappers.task_id_wrapper import TaskIdWrapper
|
|
17
|
+
from mapox.wrappers.multitask import MultiTaskWrapper
|
|
18
|
+
from mapox.wrappers.vector import VectorWrapper
|
|
19
|
+
|
|
20
|
+
# Union of the single-environment config models. MultiTaskEnvConfig
# discriminates between the members through their ``env_type`` field.
type EnvironmentConfig = (
    ReturnDiggingConfig
    | TravelingSalesmanConfig
    | ScoutsConfig
    | KingHillConfig
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class MultiTaskEnvConfig(BaseModel):
    # One named entry of a multi-task configuration.
    # Instances are frozen and unknown keys are rejected.
    model_config = ConfigDict(extra="forbid", frozen=True)
    # create_env passes this as the vec_count when building this entry's env.
    num: int = 1
    # Label matched against create_env's env_name and handed to MultiTaskWrapper.
    name: str
    # Config of the wrapped environment; the concrete model is chosen by env_type.
    env: EnvironmentConfig = Field(discriminator="env_type")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MultiTaskConfig(BaseModel):
    # Configuration bundling several named environments into one multi-task env.
    # Instances are frozen and unknown keys are rejected.
    model_config = ConfigDict(extra="forbid", frozen=True)
    env_type: Literal["multi"] = "multi"
    envs: tuple[MultiTaskEnvConfig, ...]

    @field_validator("envs", mode="before")
    @classmethod
    def coerce_envs(cls, v):
        """Turn a list (as produced by JSON parsing) into a tuple before validation."""
        if isinstance(v, list):
            return tuple(v)
        # Anything that is not a list is handed to pydantic unchanged.
        return v
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def create_env(
    env_config: EnvironmentConfig | MultiTaskConfig,
    length: int,
    vec_count: int = 1,
    env_name: str | None = None,
) -> tuple[Environment, int]:
    """Instantiate the environment described by ``env_config``.

    Args:
        env_config: Single-environment config or a multi-task config.
        length: Passed to every environment constructor.
        vec_count: When greater than 1 the result is wrapped in a
            ``VectorWrapper`` with this count.
        env_name: Only consulted for multi-task configs. When given, just the
            entry with that name is built and wrapped in a ``TaskIdWrapper``
            carrying the entry's index.

    Returns:
        ``(environment, num_tasks)``; ``num_tasks`` is the number of entries
        for a multi-task config and 1 otherwise.

    Raises:
        ValueError: If ``env_name`` matches no entry, or ``env_type`` is unknown.
    """
    is_multi = env_config.env_type == "multi"

    # Single task selected by name out of a multi-task config.
    if is_multi and env_name is not None:
        task_total = len(env_config.envs)
        for index, entry in enumerate(env_config.envs):
            if entry.name != env_name:
                continue
            inner, _ = create_env(entry.env, length, vec_count=vec_count)
            return TaskIdWrapper(inner, index), task_total
        raise ValueError("Could not find environment matching env_name")

    if is_multi:
        task_total = len(env_config.envs)
        # Each entry is vectorised by its own ``num`` before being combined.
        sub_envs = tuple(
            create_env(entry.env, length, entry.num)[0] for entry in env_config.envs
        )
        sub_names = tuple(entry.name for entry in env_config.envs)
        built = MultiTaskWrapper(sub_envs, sub_names)
    else:
        task_total = 1
        env_classes = {
            "return_digging": ReturnDiggingEnv,
            "scouts": ScoutsEnv,
            "traveling_salesman": TravelingSalesmanEnv,
            "king_hill": KingHillEnv,
        }
        env_class = env_classes.get(env_config.env_type)
        if env_class is None:
            raise ValueError(f"Unknown environment type: {env_config.env_type}")
        built = env_class(env_config, length)

    if vec_count > 1:
        built = VectorWrapper(built, vec_count)

    return built, task_total
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def create_client[State](env: Environment[State]) -> EnvironmentClient[State]:
|
|
91
|
+
# use the grid client for all environments for now
|
|
92
|
+
return GridworldClient(env)
|
mapox/environment.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from functools import cached_property
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
from jax import Array
|
|
7
|
+
|
|
8
|
+
from mapox.timestep import TimeStep
|
|
9
|
+
from mapox.specs import ObservationSpec, ActionSpec
|
|
10
|
+
import enum
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class StepType(enum.IntEnum):
    """Integer-valued step markers: ``FIRST``, ``MID`` and ``LAST``."""

    FIRST = 0
    MID = 1
    LAST = 2
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Environment[State](ABC):
|
|
20
|
+
@abstractmethod
|
|
21
|
+
def reset(self, rng_key: Array) -> tuple[State, TimeStep]: ...
|
|
22
|
+
|
|
23
|
+
@abstractmethod
|
|
24
|
+
def step(
|
|
25
|
+
self, state: State, action: Array, rng_key: Array
|
|
26
|
+
) -> tuple[State, TimeStep]: ...
|
|
27
|
+
|
|
28
|
+
@abstractmethod
|
|
29
|
+
def create_placeholder_logs(self) -> dict[str, Any]: ...
|
|
30
|
+
|
|
31
|
+
@abstractmethod
|
|
32
|
+
def create_logs(self, state) -> dict[str, Any]: ...
|
|
33
|
+
|
|
34
|
+
@cached_property
|
|
35
|
+
@abstractmethod
|
|
36
|
+
def observation_spec(self) -> ObservationSpec: ...
|
|
37
|
+
|
|
38
|
+
@cached_property
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def action_spec(self) -> ActionSpec: ...
|
|
41
|
+
|
|
42
|
+
@property
|
|
43
|
+
@abstractmethod
|
|
44
|
+
def num_agents(self) -> int: ...
|
|
45
|
+
|
|
46
|
+
@property
|
|
47
|
+
@abstractmethod
|
|
48
|
+
def is_jittable(self) -> bool: ...
|
|
49
|
+
|
|
50
|
+
@property
|
|
51
|
+
def teams(self) -> jax.Array | None:
|
|
52
|
+
return None
|
mapox/envs/__init__.py
ADDED
|
File without changes
|
mapox/envs/constance.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from jax import numpy as jnp
|
|
2
|
+
|
|
3
|
+
from mapox.specs import ObservationSpec
|
|
4
|
+
|
|
5
|
+
# Count of tile/agent ids defined below (ids run from 0 to 14).
NUM_TYPES = 15

# Unified tile ids across gridworld environments
TILE_EMPTY = 0  # empty space
TILE_WALL = 1  # permanent wall
TILE_DESTRUCTIBLE_WALL = 2  # destructible
TILE_FLAG = 3  # typical goal tile
TILE_FLAG_UNLOCKED = 4  # used for the scouting environment where the flag gets made available for taking

# TILE_FLAG_BLUE_TEAM = 5
# TILE_FLAG_RED_TEAM = 6

# agents are observed like tiles
AGENT_GENERIC = 5  # typical agent
AGENT_SCOUT = 6  # scout agent (scout env)
AGENT_HARVESTER = 7  # harvester agent (scout env)

AGENT_KNIGHT = 8
AGENT_ARCHER = 9

TILE_DECOR_1 = 10
TILE_DECOR_2 = 11
TILE_DECOR_3 = 12
TILE_DECOR_4 = 13

TILE_ARROW = 14


# Actions
# Count of action ids defined below (ids run from 0 to 6).
NUM_ACTIONS = 7

MOVE_UP = 0
MOVE_RIGHT = 1
MOVE_DOWN = 2
MOVE_LEFT = 3
STAY = 4
PRIMARY_ACTION = 5
DIG_ACTION = 6


# Position deltas in the order MOVE_UP, MOVE_RIGHT, MOVE_DOWN, MOVE_LEFT;
# environments add the selected row to an agent's (x, y) position.
DIRECTIONS = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]], dtype=jnp.int32)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def make_obs_spec(width: int, height: int) -> ObservationSpec:
    """Describe a ``width`` x ``height`` int8 observation with four channels.

    The channels are, in order: tile, direction, team id, health.
    """
    # One bound per channel, listed in channel order.
    channel_bounds = (
        NUM_TYPES,  # tile
        5,  # direction: none, up, right, down, left
        3,  # team id: none, red, blue
        3,  # health: 0, 1, 2
    )

    return ObservationSpec(
        dtype=jnp.int8,
        shape=(width, height, len(channel_bounds)),
        max_value=channel_bounds,
    )
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def make_action_mask(actions: list[int], num_agents: int):
    """Build a per-agent boolean mask of permitted actions.

    Args:
        actions: Action ids that are allowed; every other id is masked out.
        num_agents: Number of rows in the result (one identical row per agent).

    Returns:
        Boolean array of shape ``(num_agents, NUM_ACTIONS)`` that is ``True``
        at the allowed action ids.

    Raises:
        IndexError: If an action id is ``>= NUM_ACTIONS``.
    """
    allowed = [False] * NUM_ACTIONS
    for action in actions:
        allowed[action] = True

    # jnp.bool_ is the dtype spelling used elsewhere in the package and is also
    # available on JAX releases that predate the jnp.bool alias.
    row = jnp.array(allowed, jnp.bool_)
    return jnp.repeat(row[None, :], num_agents, axis=0)
|
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
from functools import cached_property, partial
|
|
2
|
+
from typing import NamedTuple, Literal
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
from jax import numpy as jnp
|
|
6
|
+
import numpy as np
|
|
7
|
+
from pydantic import BaseModel, ConfigDict
|
|
8
|
+
|
|
9
|
+
from mapox.map_generator import (
|
|
10
|
+
fractal_noise,
|
|
11
|
+
generate_decor_tiles,
|
|
12
|
+
choose_positions,
|
|
13
|
+
)
|
|
14
|
+
from mapox.environment import Environment
|
|
15
|
+
from mapox.specs import DiscreteActionSpec, ObservationSpec
|
|
16
|
+
from mapox.timestep import TimeStep
|
|
17
|
+
from mapox.envs.renderer import GridRenderSettings, GridRenderState
|
|
18
|
+
import mapox.envs.constance as GW
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ReturnDiggingConfig(BaseModel):
    # Settings for ReturnDiggingEnv. Instances are frozen and unknown keys are rejected.
    model_config = ConfigDict(extra="forbid", frozen=True)
    # Discriminator value that selects this config in create_env.
    env_type: Literal["return_digging"] = "return_digging"

    # Number of agents, and number of flag tiles placed at reset.
    num_agents: int = 1
    num_flags: int = 1

    # Size of the generated map before padding, and of each agent's view window.
    width: int = 40
    height: int = 40
    view_width: int = 11
    view_height: int = 11

    # Stored on the environment as ``mapgen_threshold``.
    mapgen_threshold: float = 0.3
    # Timeout assigned to an agent when it moves into a destructible wall.
    digging_timeout: int = 5
    # Reward given when an agent moves onto a flag tile.
    treasure_reward: float = 1.0

    # When True, reset() overwrites part of the map with a fixed layout.
    eval_map: bool = False
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ReturnDiggingState(NamedTuple):
    """State carried between steps of ``ReturnDiggingEnv``."""

    # One (x, y) row per agent, in padded-map coordinates.
    agents_pos: jax.Array
    # Per-agent countdown; while it is above zero the agent does not move.
    agents_timeout: jax.Array
    # Per-agent flag, OR-ed with the step rewards on every step.
    found_reward: jax.Array

    # Step counter, starts at 0 and grows by 1 per step.
    time: jax.Array

    # Padded tile grid.
    map: jax.Array
    # Candidate (x, y) respawn positions and how many of them are valid.
    spawn_pos: jax.Array
    spawn_count: jax.Array

    # Running sum of the per-step mean reward over agents.
    rewards: jax.Array
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ReturnDiggingEnv(Environment[ReturnDiggingState]):
|
|
55
|
+
def __init__(self, config: ReturnDiggingConfig, length: int) -> None:
    """Copy the settings from ``config`` and precompute padding and the action mask.

    Args:
        config: Environment settings.
        length: Episode length; a timestep is marked terminated when the step
            counter equals ``length - 1``.
    """
    super().__init__()

    self._config = config
    self._length = length
    self._num_agents = config.num_agents
    self.num_flags = config.num_flags

    # Map size before the wall border is added.
    self.unpadded_width, self.unpadded_height = config.width, config.height

    # The border is half a view window wide.
    self.view_width, self.view_height = config.view_width, config.view_height
    self.pad_width = self.view_width // 2
    self.pad_height = self.view_height // 2

    # NOTE(review): the map is padded on both sides when generated, yet only a
    # single pad is added here -- confirm what ``width``/``height`` should mean.
    self.width = self.unpadded_width + self.pad_width
    self.height = self.unpadded_height + self.pad_height

    self.mapgen_threshold = config.mapgen_threshold
    self.digging_timeout = config.digging_timeout
    self.treasure_reward = config.treasure_reward

    # Only the four movement actions are enabled in this environment.
    movement_actions = [GW.MOVE_UP, GW.MOVE_RIGHT, GW.MOVE_DOWN, GW.MOVE_LEFT]
    self._action_mask = GW.make_action_mask(movement_actions, self.num_agents)
|
|
87
|
+
|
|
88
|
+
def _generate_map(self, rng_key):
    """Generate a padded tile map plus the list of empty spawn cells.

    Returns:
        ``(tiles, spawn_pos, spawn_count)``: the wall-padded tile grid, an
        array of (x, y) positions of empty cells in padded coordinates (sized
        to the full unpadded grid, unused slots filled from ``-1``), and the
        number of empty cells.
    """
    inner_w, inner_h = self.unpadded_width, self.unpadded_height

    # Three keys are drawn; only the first two are consumed here.
    walls_key, decor_key, rng_key = jax.random.split(rng_key, 3)

    noise = fractal_noise(inner_w, inner_h, [2, 4, 5, 8, 10], walls_key)
    tiles = generate_decor_tiles(inner_w, inner_h, decor_key)
    # NOTE(review): the threshold is hard-coded to 0.05 and
    # ``self.mapgen_threshold`` is not consulted -- confirm this is intended.
    tiles = jnp.where(noise > 0.05, jnp.int8(GW.TILE_DESTRUCTIBLE_WALL), tiles)

    # Empty cells (before padding) are the candidate spawn locations.
    is_empty = tiles == GW.TILE_EMPTY
    x_spawns, y_spawns = jnp.where(
        is_empty, size=inner_w * inner_h, fill_value=jnp.int8(-1)
    )
    spawn_count = jnp.sum(is_empty)

    # Surround the map with permanent wall.
    border = (
        (self.pad_width, self.pad_width),
        (self.pad_height, self.pad_height),
    )
    tiles = jnp.pad(
        tiles, pad_width=border, mode="constant", constant_values=GW.TILE_WALL
    )

    # Shift the spawn cells into padded coordinates.
    spawn_pos = jnp.stack(
        (x_spawns + self.pad_width, y_spawns + self.pad_height), axis=1
    )

    return tiles, spawn_pos, spawn_count
|
|
124
|
+
|
|
125
|
+
def reset(self, rng_key: jax.Array) -> tuple[ReturnDiggingState, TimeStep]:
    """Generate a fresh map, place flags and agents, and return the first timestep.

    With ``eval_map`` enabled in the config, a fixed region of the generated
    map is overwritten (hard-coded coordinates), the first agent is moved to a
    fixed cell and a single flag is placed at a fixed cell instead of the
    sampled flag positions.
    """
    map_key, pos_key = jax.random.split(rng_key)
    tiles, spawn_pos, spawn_count = self._generate_map(map_key)

    # Positions are sampled on the interior and then shifted into padded coordinates.
    interior = tiles[
        self.pad_width : -self.pad_width, self.pad_height : -self.pad_height
    ]
    pos_x, pos_y = choose_positions(
        interior,
        self.num_flags + self.num_agents,
        pos_key,
        replace=False,
    )
    positions = jnp.stack(
        (pos_x + self.pad_width, pos_y + self.pad_height), axis=1
    )

    # The first ``num_flags`` samples are flags, the remainder agents.
    flag_pos = positions[: self.num_flags]
    agents_pos = positions[self.num_flags :]

    if not self._config.eval_map:
        tiles = tiles.at[flag_pos[:, 0], flag_pos[:, 1]].set(GW.TILE_FLAG)
    else:
        generated = tiles
        tiles = tiles.at[36:45, 10:30].set(GW.TILE_DESTRUCTIBLE_WALL)
        # Restore the generated tiles inside the inner window.
        tiles = tiles.at[42:45, 17:23].set(generated[42:45, 17:23])
        agents_pos = agents_pos.at[0].set([44, 22])
        tiles = tiles.at[43, 18].set(GW.TILE_FLAG)

    state = ReturnDiggingState(
        map=tiles,
        spawn_pos=spawn_pos,
        spawn_count=spawn_count,
        agents_pos=agents_pos,
        agents_timeout=jnp.zeros((self.num_agents,), dtype=jnp.int32),
        found_reward=jnp.zeros((self.num_agents,), dtype=jnp.bool_),
        time=jnp.int32(0),
        rewards=jnp.float32(0.0),
    )

    # The first timestep reports all-zero previous actions and rewards.
    initial_actions = jnp.zeros((self.num_agents,), dtype=jnp.int32)
    initial_rewards = jnp.zeros((self.num_agents,), dtype=jnp.float32)

    return state, self.encode_observations(state, initial_actions, initial_rewards)
|
|
171
|
+
|
|
172
|
+
def load_map(self, map: str):
|
|
173
|
+
tiles = np.zeros((self.unpadded_width, self.unpadded_height), dtype=np.int8)
|
|
174
|
+
|
|
175
|
+
x = 0
|
|
176
|
+
y = self.unpadded_height
|
|
177
|
+
|
|
178
|
+
agent_positions = []
|
|
179
|
+
spawn_positions = []
|
|
180
|
+
|
|
181
|
+
for c in map:
|
|
182
|
+
if c == "\n":
|
|
183
|
+
x = 0
|
|
184
|
+
y -= 1
|
|
185
|
+
else:
|
|
186
|
+
match x:
|
|
187
|
+
case "x":
|
|
188
|
+
tiles[x, y] = GW.TILE_DESTRUCTIBLE_WALL
|
|
189
|
+
case "a":
|
|
190
|
+
agent_positions.append(
|
|
191
|
+
[self.pad_width + x, self.pad_height + y]
|
|
192
|
+
)
|
|
193
|
+
case "f":
|
|
194
|
+
tiles[x, y] = GW.TILE_FLAG
|
|
195
|
+
case _:
|
|
196
|
+
spawn_positions.append(
|
|
197
|
+
[self.pad_width + x, self.pad_height + y]
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
x += 1
|
|
201
|
+
|
|
202
|
+
tiles = jnp.asarray(tiles)
|
|
203
|
+
# pad the tiles
|
|
204
|
+
tiles = jnp.pad(
|
|
205
|
+
tiles,
|
|
206
|
+
pad_width=(
|
|
207
|
+
(self.pad_width, self.pad_width),
|
|
208
|
+
(self.pad_height, self.pad_height),
|
|
209
|
+
),
|
|
210
|
+
mode="constant",
|
|
211
|
+
constant_values=GW.TILE_WALL,
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
state = ReturnDiggingState(
|
|
215
|
+
map=tiles,
|
|
216
|
+
spawn_pos=jnp.array(spawn_positions, jnp.int32),
|
|
217
|
+
spawn_count=jnp.int32(len(spawn_positions)),
|
|
218
|
+
agents_pos=jnp.arange(agent_positions, jnp.int32),
|
|
219
|
+
agents_timeout=jnp.zeros((self.num_agents,), dtype=jnp.int32),
|
|
220
|
+
found_reward=jnp.zeros((self.num_agents,), dtype=jnp.bool_),
|
|
221
|
+
time=jnp.int32(0),
|
|
222
|
+
rewards=jnp.float32(0.0),
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
actions = jnp.zeros((self.num_agents,), dtype=jnp.int32)
|
|
226
|
+
rewards = jnp.zeros((self.num_agents,), dtype=jnp.float32)
|
|
227
|
+
|
|
228
|
+
return state, self.encode_observations(state, actions, rewards)
|
|
229
|
+
|
|
230
|
+
@cached_property
def observation_spec(self) -> ObservationSpec:
    """Observation spec sized to the agent view window (computed once, then cached)."""
    view_size = (self.view_width, self.view_height)
    return GW.make_obs_spec(*view_size)
|
|
233
|
+
|
|
234
|
+
@cached_property
def action_spec(self) -> DiscreteActionSpec:
    """Discrete action spec covering all ``GW.NUM_ACTIONS`` action ids."""
    spec = DiscreteActionSpec(num_actions=GW.NUM_ACTIONS)
    return spec
|
|
237
|
+
|
|
238
|
+
@property
def is_jittable(self) -> bool:
    """This environment always reports itself as jittable."""
    jittable = True
    return jittable
|
|
241
|
+
|
|
242
|
+
@property
def num_agents(self) -> int:
    """Number of agents, as stored by ``__init__``."""
    count = self._num_agents
    return count
|
|
245
|
+
|
|
246
|
+
def step(
    self, state: ReturnDiggingState, action: jax.Array, rng_key: jax.Array
) -> tuple[ReturnDiggingState, TimeStep]:
    """Advance every agent by one step.

    An agent with a positive timeout only counts the timeout down. Otherwise
    it tries to move in the chosen direction: walls block the move, moving
    into a destructible wall removes that wall and sets the agent's timeout
    to ``digging_timeout``, and moving onto a flag pays ``treasure_reward``
    and relocates the agent to a randomly drawn spawn position.

    Args:
        state: Current environment state.
        action: One action id per agent, used to index ``GW.DIRECTIONS``.
        rng_key: Key used to draw one candidate respawn position per agent.

    Returns:
        The updated state and the timestep encoded from it.
    """
    # Vectorised over agents: position, timeout, action and respawn candidate.
    @partial(jax.vmap, in_axes=(0, 0, 0, 0), out_axes=(0, 0, 0, 0))
    def _step_agent(local_position, timeout, local_action, random_position):
        # Both branches return (new position, targeted cell, new timeout, reward).
        def _step_timeout(local_position, timeout, local_action, random_position):
            # Agent is waiting: stay put and target its own cell, so the dig
            # update below leaves the map unchanged for this agent.
            return local_position, local_position, timeout - 1, 0.0

        def _step_move(local_position, timeout, local_action, random_position):
            target_pos = local_position + GW.DIRECTIONS[local_action]

            # Tile lookup uses the map as it was before this step.
            new_tile = state.map[target_pos[0], target_pos[1]]

            # don't move if we are moving into a wall
            new_pos = jnp.where(
                jnp.logical_or(
                    new_tile == GW.TILE_WALL, new_tile == GW.TILE_DESTRUCTIBLE_WALL
                ),
                local_position,
                target_pos,
            )

            found_treasure = new_tile == GW.TILE_FLAG
            reward = jnp.where(found_treasure, self.treasure_reward, 0.0)

            # randomize position if the agent finds the reward
            new_pos = jnp.where(found_treasure, random_position, new_pos)

            # sets a timeout if the tile is dug
            timeout = jnp.where(
                new_tile == GW.TILE_DESTRUCTIBLE_WALL, self.digging_timeout, 0
            )

            return new_pos, target_pos, timeout, reward

        return jax.lax.cond(
            timeout > 0,
            _step_timeout,
            _step_move,
            local_position,
            timeout,
            local_action,
            random_position,
        )

    # One candidate respawn cell per agent, drawn from the first
    # ``spawn_count`` entries of the spawn list.
    random_positions = state.spawn_pos[
        jax.random.randint(
            rng_key, (self._num_agents,), minval=0, maxval=state.spawn_count
        )
    ]
    new_position, target_pos, timeout, rewards = _step_agent(
        state.agents_pos, state.agents_timeout, action, random_positions
    )

    # dig actions
    # Every targeted destructible wall becomes empty; other tiles are rewritten unchanged.
    target_tiles = state.map[target_pos[:, 0], target_pos[:, 1]]
    map = state.map.at[target_pos[:, 0], target_pos[:, 1]].set(
        jnp.where(
            target_tiles == GW.TILE_DESTRUCTIBLE_WALL, GW.TILE_EMPTY, target_tiles
        )
    )
    # /dig actions

    state = state._replace(
        agents_pos=new_position,
        agents_timeout=timeout,
        # rewards is a float array here; a non-zero reward counts as True
        found_reward=jnp.logical_or(state.found_reward, rewards),
        time=state.time + 1,
        # the logged total accumulates the mean reward over agents
        rewards=state.rewards + jnp.mean(rewards),
        map=map,
    )

    return state, self.encode_observations(state, action, rewards)
|
|
319
|
+
|
|
320
|
+
def _render_tiles(self, state: ReturnDiggingState):
    """Return the map with agents drawn in, as a four-channel grid.

    Channel 0 holds the tile ids (agents overwrite the tile they stand on
    with ``GW.AGENT_GENERIC``); the direction, team and health channels are
    all zero in this environment.
    """
    agent_x = state.agents_pos[:, 0]
    agent_y = state.agents_pos[:, 1]
    tiles = state.map.at[agent_x, agent_y].set(GW.AGENT_GENERIC)

    # The three remaining channels carry no information here.
    blank = jnp.zeros_like(tiles, dtype=jnp.int8)
    directions, teams, health = blank, blank, blank

    return jnp.stack((tiles, directions, teams, health), axis=-1)
|
|
339
|
+
|
|
340
|
+
def encode_observations(
    self, state: ReturnDiggingState, actions, rewards
) -> TimeStep:
    """Build the per-agent ``TimeStep`` for ``state``.

    Each agent observes a ``view_width`` x ``view_height`` window of the
    rendered grid starting half a window below/left of its position.
    ``actions`` and ``rewards`` are stored as the last action and last reward.
    """
    half_w = self.view_width // 2
    half_h = self.view_height // 2
    window = (self.view_width, self.view_height, self.observation_spec.shape[-1])

    rendered = self._render_tiles(state)

    def _crop(position):
        # Window origin for one agent; all channels are kept.
        origin = (position[0] - half_w, position[1] - half_h, 0)
        return jax.lax.dynamic_slice(rendered, origin, window)

    view = jax.vmap(_crop)(state.agents_pos)

    # Every agent sees the same step counter.
    time = jnp.repeat(state.time[None], self.num_agents, axis=0)
    terminated = jnp.equal(time, self._length - 1)

    return TimeStep(
        obs=view,
        time=time,
        last_action=actions,
        last_reward=rewards,
        action_mask=self._action_mask,
        terminated=terminated,
    )
|
|
368
|
+
|
|
369
|
+
def create_placeholder_logs(self):
    """Return a log dict with a zero ``"rewards"`` entry."""
    zero_reward = jnp.float32(0.0)
    return {"rewards": zero_reward}
|
|
371
|
+
|
|
372
|
+
def create_logs(self, state: ReturnDiggingState):
    """Return a log dict exposing the reward total accumulated in ``state``."""
    accumulated = state.rewards
    return {"rewards": accumulated}
|
|
374
|
+
|
|
375
|
+
def get_render_state(self, state: ReturnDiggingState) -> GridRenderState:
    """Package the rendered tile grid and the agent positions for the renderer."""
    rendered = self._render_tiles(state)
    return GridRenderState(tilemap=rendered, agent_positions=state.agents_pos)
|
|
382
|
+
|
|
383
|
+
def get_render_settings(self) -> GridRenderSettings:
    """Report the unpadded map size and the agent view size to the renderer."""
    settings = GridRenderSettings(
        tile_width=self.unpadded_width,
        tile_height=self.unpadded_height,
        view_width=self.view_width,
        view_height=self.view_height,
    )
    return settings
|