mapox 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mapox/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ """MAPOX: Multi-Agent Partially Observable gridworlds in JAX"""
2
+
3
# Package version string; matches the published wheel version (0.1.0).
__version__: str = "0.1.0"
4
+
5
+ from mapox.timestep import TimeStep
6
+ from mapox.environment import Environment
7
+ from mapox.specs import ActionSpec, ObservationSpec
8
+ from mapox.wrappers.multitask import MultiTaskWrapper
9
+ from mapox.wrappers.vector import VectorWrapper
10
+ from mapox.client import GridworldClient
11
+ from mapox.utils.encode_one_hot import concat_one_hot
12
+ from mapox.config import create_env, create_client, ScoutsConfig, TravelingSalesmanConfig, KingHillConfig, MultiTaskConfig, EnvironmentConfig
13
+
14
# Public API of the package (order preserved from the original listing).
__all__ = [
    # Core environment types and specs.
    "TimeStep",
    "ActionSpec",
    "ObservationSpec",
    "Environment",
    # Wrappers.
    "MultiTaskWrapper",
    "VectorWrapper",
    # Rendering and utilities.
    "GridworldClient",
    "concat_one_hot",
    # Factories and configuration models.
    "create_env",
    "create_client",
    "ScoutsConfig",
    "TravelingSalesmanConfig",
    "KingHillConfig",
    "MultiTaskConfig",
    "EnvironmentConfig",
]
mapox/client.py ADDED
@@ -0,0 +1,53 @@
1
+ import pygame
2
+ from abc import ABC, abstractmethod
3
+
4
+ from mapox.renderer import GridworldRenderer, GridRenderState
5
+ from mapox.timestep import TimeStep
6
+
7
+
8
+ class EnvironmentClient[State](ABC):
9
+ @abstractmethod
10
+ def render(self, state: State, timestep: TimeStep): ...
11
+
12
+ @abstractmethod
13
+ def save_video(self): ...
14
+
15
+
16
class GridworldClient(EnvironmentClient):
    """EnvironmentClient that renders via GridworldRenderer using per-env adapters.

    The wrapped environment must provide ``get_render_settings()``, which is
    called once here to configure the renderer. It must also provide
    ``get_render_state(state)``, which is called on every render to obtain a
    ``GridRenderState``. Subclassing ``EnvironmentClient`` makes the object
    returned by ``create_client`` actually satisfy its declared interface.
    """

    def __init__(
        self, env, screen_width: int = 960, screen_height: int = 960, fps: int = 10
    ):
        """Create a renderer sized ``screen_width`` x ``screen_height`` at ``fps``.

        Raises:
            AssertionError: If ``env`` lacks the render adapter methods.
        """
        assert hasattr(env, "get_render_state"), (
            "Env must implement get_render_state(state)"
        )
        assert hasattr(env, "get_render_settings"), (
            "Env must implement get_render_settings()"
        )
        self.env = env
        self.renderer = GridworldRenderer(
            screen_width=screen_width, screen_height=screen_height, fps=fps
        )
        self.renderer.set_env(env.get_render_settings())

    def render(self, state, timestep):
        """Render the full map for ``state``; ``timestep`` is currently unused."""
        rs: GridRenderState = self.env.get_render_state(state)
        self.renderer.render(rs)

    def render_pov(self, state, timestep):
        """Render only the focused agent's point-of-view, filling the screen.

        ``timestep`` is currently unused.
        """
        rs: GridRenderState = self.env.get_render_state(state)
        self.renderer.render_agent_view(rs)

    def handle_event(self, event: pygame.event.Event) -> bool:
        """Forward a pygame event to the renderer and return its result."""
        return self.renderer.handle_event(event)

    def record_frame(self):
        """Ask the renderer to record the current frame."""
        self.renderer.record_frame()

    def save_video(self, file_name: str):
        """Ask the renderer to save recorded frames to ``file_name``."""
        self.renderer.save_video(file_name)

    def focus_agent(self, agent_id: int | None):
        """Focus the renderer on ``agent_id`` (``None`` presumably clears focus)."""
        self.renderer.focus_agent(agent_id)
mapox/config.py ADDED
@@ -0,0 +1,92 @@
1
+ from typing import Literal
2
+
3
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
4
+
5
+ from mapox.envs.king_hill import KingHillConfig, KingHillEnv
6
+ from mapox.envs.grid_return import ReturnDiggingConfig, ReturnDiggingEnv
7
+ from mapox.envs.traveling_salesman import (
8
+ TravelingSalesmanConfig,
9
+ TravelingSalesmanEnv,
10
+ )
11
+ from mapox.envs.scouts import ScoutsConfig, ScoutsEnv
12
+ from mapox.client import GridworldClient
13
+
14
+ from mapox.client import EnvironmentClient
15
+ from mapox.environment import Environment
16
+ from mapox.wrappers.task_id_wrapper import TaskIdWrapper
17
+ from mapox.wrappers.multitask import MultiTaskWrapper
18
+ from mapox.wrappers.vector import VectorWrapper
19
+
20
+ type EnvironmentConfig = (
21
+ ReturnDiggingConfig
22
+ | TravelingSalesmanConfig
23
+ | ScoutsConfig
24
+ | KingHillConfig
25
+ )
26
+
27
+
28
class MultiTaskEnvConfig(BaseModel):
    """One named task inside a ``MultiTaskConfig``."""

    model_config = ConfigDict(frozen=True, extra="forbid")

    # create_env passes this as the vec_count for this task's environment.
    num: int = 1
    # Task name; create_env(env_name=...) selects a single task by it.
    name: str
    # The task's environment; the concrete model is chosen by env_type.
    env: EnvironmentConfig = Field(discriminator="env_type")
33
+
34
+
35
class MultiTaskConfig(BaseModel):
    """Configuration that combines several named environments into one."""

    model_config = ConfigDict(frozen=True, extra="forbid")

    env_type: Literal["multi"] = "multi"
    envs: tuple[MultiTaskEnvConfig, ...]

    @field_validator("envs", mode="before")
    @classmethod
    def coerce_envs(cls, v):
        """Turn a list (what JSON decoding yields) into a tuple; pass anything else through."""
        if not isinstance(v, list):
            return v
        return tuple(v)
45
+
46
+
47
def create_env(
    env_config: EnvironmentConfig | MultiTaskConfig,
    length: int,
    vec_count: int = 1,
    env_name: str | None = None,
) -> tuple[Environment, int]:
    """Instantiate the environment described by ``env_config``.

    Args:
        env_config: A single-environment config or a ``MultiTaskConfig``.
        length: Episode length forwarded to every concrete environment.
        vec_count: When greater than 1, the result is wrapped in a
            ``VectorWrapper`` with this many copies.
        env_name: Only used with a multi-task config. When given, only the
            task with this name is built, wrapped in ``TaskIdWrapper`` with
            its index, instead of the full ``MultiTaskWrapper``.

    Returns:
        ``(env, num_tasks)``, where ``num_tasks`` is the number of tasks of a
        multi-task config and 1 otherwise.

    Raises:
        ValueError: If ``env_name`` matches no task or ``env_type`` is unknown.
    """
    is_multi = env_config.env_type == "multi"

    if is_multi and env_name is not None:
        task_count = len(env_config.envs)
        found = next(
            (
                (idx, task)
                for idx, task in enumerate(env_config.envs)
                if task.name == env_name
            ),
            None,
        )
        if found is None:
            raise ValueError("Could not find environment matching env_name")
        idx, task = found
        single_env, _ = create_env(task.env, length, vec_count=vec_count)
        return TaskIdWrapper(single_env, idx), task_count

    # Concrete environment class for each single-environment env_type.
    constructors = {
        "return_digging": ReturnDiggingEnv,
        "scouts": ScoutsEnv,
        "traveling_salesman": TravelingSalesmanEnv,
        "king_hill": KingHillEnv,
    }

    if is_multi:
        task_count = len(env_config.envs)
        # Each task is vectorized by its own ``num`` before being combined.
        sub_envs = tuple(
            create_env(task.env, length, task.num)[0] for task in env_config.envs
        )
        task_names = tuple(task.name for task in env_config.envs)
        env = MultiTaskWrapper(sub_envs, task_names)
    elif env_config.env_type in constructors:
        task_count = 1
        env = constructors[env_config.env_type](env_config, length)
    else:
        raise ValueError(f"Unknown environment type: {env_config.env_type}")

    if vec_count > 1:
        env = VectorWrapper(env, vec_count)

    return env, task_count
88
+
89
+
90
+ def create_client[State](env: Environment[State]) -> EnvironmentClient[State]:
91
+ # use the grid client for all environments for now
92
+ return GridworldClient(env)
mapox/environment.py ADDED
@@ -0,0 +1,52 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import cached_property
3
+ from typing import Any
4
+
5
+ import jax
6
+ from jax import Array
7
+
8
+ from mapox.timestep import TimeStep
9
+ from mapox.specs import ObservationSpec, ActionSpec
10
+ import enum
11
+
12
+
13
class StepType(enum.IntEnum):
    """Integer tags presumably marking where a timestep falls in an episode.

    Not referenced elsewhere in this module.
    """

    FIRST = 0  # presumably the first step after reset
    MID = 1  # presumably any intermediate step
    LAST = 2  # presumably the final step
17
+
18
+
19
+ class Environment[State](ABC):
20
+ @abstractmethod
21
+ def reset(self, rng_key: Array) -> tuple[State, TimeStep]: ...
22
+
23
+ @abstractmethod
24
+ def step(
25
+ self, state: State, action: Array, rng_key: Array
26
+ ) -> tuple[State, TimeStep]: ...
27
+
28
+ @abstractmethod
29
+ def create_placeholder_logs(self) -> dict[str, Any]: ...
30
+
31
+ @abstractmethod
32
+ def create_logs(self, state) -> dict[str, Any]: ...
33
+
34
+ @cached_property
35
+ @abstractmethod
36
+ def observation_spec(self) -> ObservationSpec: ...
37
+
38
+ @cached_property
39
+ @abstractmethod
40
+ def action_spec(self) -> ActionSpec: ...
41
+
42
+ @property
43
+ @abstractmethod
44
+ def num_agents(self) -> int: ...
45
+
46
+ @property
47
+ @abstractmethod
48
+ def is_jittable(self) -> bool: ...
49
+
50
+ @property
51
+ def teams(self) -> jax.Array | None:
52
+ return None
mapox/envs/__init__.py ADDED
File without changes
@@ -0,0 +1,75 @@
1
+ from jax import numpy as jnp
2
+
3
+ from mapox.specs import ObservationSpec
4
+
5
# Unified tile ids shared by every gridworld environment. Agents are written
# into the same tile channel as terrain, so agent ids share this id space.
TILE_EMPTY = 0  # empty space
TILE_WALL = 1  # permanent wall
TILE_DESTRUCTIBLE_WALL = 2  # wall that can be dug through
TILE_FLAG = 3  # typical goal tile
TILE_FLAG_UNLOCKED = 4  # used by the scouts environment once the flag becomes available for taking

# Agents are observed like tiles.
AGENT_GENERIC = 5  # typical agent
AGENT_SCOUT = 6  # scout agent (scouts env)
AGENT_HARVESTER = 7  # harvester agent (scouts env)

AGENT_KNIGHT = 8
AGENT_ARCHER = 9

TILE_DECOR_1 = 10
TILE_DECOR_2 = 11
TILE_DECOR_3 = 12
TILE_DECOR_4 = 13

TILE_ARROW = 14

# Number of distinct ids above (ids run 0 .. NUM_TYPES - 1). It is derived
# from the highest id so adding a new id only requires updating this line
# if that id becomes the new maximum.
NUM_TYPES = TILE_ARROW + 1


# Actions
MOVE_UP = 0
MOVE_RIGHT = 1
MOVE_DOWN = 2
MOVE_LEFT = 3
STAY = 4
PRIMARY_ACTION = 5
DIG_ACTION = 6

# Size of the action space (action ids run 0 .. NUM_ACTIONS - 1).
NUM_ACTIONS = DIG_ACTION + 1


# (dx, dy) offset for each move action, indexed by MOVE_UP..MOVE_LEFT.
# "Up" increases y.
DIRECTIONS = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]], dtype=jnp.int32)
46
+
47
+
48
def make_obs_spec(width: int, height: int) -> ObservationSpec:
    """Build the int8 observation spec for a ``width`` x ``height`` view.

    Every cell has four channels: tile id, direction, team id and health.
    ``max_value`` appears to hold the number of distinct values per channel
    (e.g. ``NUM_TYPES`` tile ids); confirm against ``ObservationSpec``.
    """
    channel_max_values = (
        NUM_TYPES,  # tile id
        5,  # direction: none, up, right, down, left
        3,  # team: none, red, blue
        3,  # health: 0, 1, 2
    )
    return ObservationSpec(
        dtype=jnp.int8,
        shape=(width, height, len(channel_max_values)),
        max_value=channel_max_values,
    )
66
+
67
+
68
def make_action_mask(actions: list[int], num_agents: int):
    """Return a ``(num_agents, NUM_ACTIONS)`` boolean mask permitting ``actions``.

    Every agent gets the same row: True at each index listed in ``actions``,
    False everywhere else.
    """
    allowed = [False for _ in range(NUM_ACTIONS)]
    for action_id in actions:
        allowed[action_id] = True

    row = jnp.asarray(allowed, dtype=jnp.bool_)
    return jnp.tile(row, (num_agents, 1))
@@ -0,0 +1,389 @@
1
+ from functools import cached_property, partial
2
+ from typing import NamedTuple, Literal
3
+
4
+ import jax
5
+ from jax import numpy as jnp
6
+ import numpy as np
7
+ from pydantic import BaseModel, ConfigDict
8
+
9
+ from mapox.map_generator import (
10
+ fractal_noise,
11
+ generate_decor_tiles,
12
+ choose_positions,
13
+ )
14
+ from mapox.environment import Environment
15
+ from mapox.specs import DiscreteActionSpec, ObservationSpec
16
+ from mapox.timestep import TimeStep
17
+ from mapox.envs.renderer import GridRenderSettings, GridRenderState
18
+ import mapox.envs.constance as GW
19
+
20
+
21
class ReturnDiggingConfig(BaseModel):
    """Configuration for ``ReturnDiggingEnv`` (``env_type="return_digging"``)."""

    model_config = ConfigDict(frozen=True, extra="forbid")
    env_type: Literal["return_digging"] = "return_digging"

    # Population.
    num_agents: int = 1
    num_flags: int = 1

    # Unpadded map size and per-agent view size.
    width: int = 40
    height: int = 40
    view_width: int = 11
    view_height: int = 11

    # NOTE(review): ReturnDiggingEnv stores this but its map generator uses a
    # hard-coded 0.05 threshold instead — confirm the intended value.
    mapgen_threshold: float = 0.3
    digging_timeout: int = 5  # steps an agent stays frozen after digging
    treasure_reward: float = 1.0  # reward for stepping onto a flag

    eval_map: bool = False  # use the fixed evaluation layout in reset()
38
+
39
+
40
class ReturnDiggingState(NamedTuple):
    """State of one ``ReturnDiggingEnv`` episode; positions are padded-map coordinates."""

    # Per-agent data: position, remaining frozen (digging) steps, and whether
    # the agent has received a nonzero reward this episode.
    agents_pos: jax.Array
    agents_timeout: jax.Array
    found_reward: jax.Array

    time: jax.Array  # number of steps since reset

    # Padded tile grid plus the empty tiles agents may teleport to (only the
    # first ``spawn_count`` rows of ``spawn_pos`` are valid).
    map: jax.Array
    spawn_pos: jax.Array
    spawn_count: jax.Array

    rewards: jax.Array  # running sum of the mean per-agent reward
52
+
53
+
54
+ class ReturnDiggingEnv(Environment[ReturnDiggingState]):
55
+ def __init__(self, config: ReturnDiggingConfig, length: int) -> None:
56
+ super().__init__()
57
+
58
+ self._config = config
59
+ self._length = length
60
+ self._num_agents = config.num_agents
61
+ self.num_flags = config.num_flags
62
+
63
+ self.unpadded_width = config.width
64
+ self.unpadded_height = config.height
65
+
66
+ self.view_width = config.view_width
67
+ self.view_height = config.view_height
68
+ self.pad_width = self.view_width // 2
69
+ self.pad_height = self.view_height // 2
70
+
71
+ self.width = self.unpadded_width + self.pad_width
72
+ self.height = self.unpadded_height + self.pad_height
73
+
74
+ self.mapgen_threshold = config.mapgen_threshold
75
+ self.digging_timeout = config.digging_timeout
76
+ self.treasure_reward = config.treasure_reward
77
+
78
+ self._action_mask = GW.make_action_mask(
79
+ [
80
+ GW.MOVE_UP,
81
+ GW.MOVE_RIGHT,
82
+ GW.MOVE_DOWN,
83
+ GW.MOVE_LEFT,
84
+ ],
85
+ self.num_agents,
86
+ )
87
+
88
+ def _generate_map(self, rng_key):
89
+ walls_key, decor_key, rng_key = jax.random.split(rng_key, 3)
90
+ noise = fractal_noise(
91
+ self.unpadded_width, self.unpadded_height, [2, 4, 5, 8, 10], walls_key
92
+ )
93
+
94
+ tiles = generate_decor_tiles(
95
+ self.unpadded_width, self.unpadded_height, decor_key
96
+ )
97
+ tiles = jnp.where(noise > 0.05, jnp.int8(GW.TILE_DESTRUCTIBLE_WALL), tiles)
98
+
99
+ # get the empty tiles for spawning
100
+ x_spawns, y_spawns = jnp.where(
101
+ tiles == GW.TILE_EMPTY,
102
+ size=self.unpadded_width * self.unpadded_height,
103
+ fill_value=jnp.int8(-1),
104
+ )
105
+ spawn_count = jnp.sum(tiles == GW.TILE_EMPTY)
106
+
107
+ # pad the tiles
108
+ tiles = jnp.pad(
109
+ tiles,
110
+ pad_width=(
111
+ (self.pad_width, self.pad_width),
112
+ (self.pad_height, self.pad_height),
113
+ ),
114
+ mode="constant",
115
+ constant_values=GW.TILE_WALL,
116
+ )
117
+
118
+ # pad the empty tiles
119
+ y_spawns = y_spawns + self.pad_height
120
+ x_spawns = x_spawns + self.pad_width
121
+ spawn_pos = jnp.stack((x_spawns, y_spawns), axis=1)
122
+
123
+ return tiles, spawn_pos, spawn_count
124
+
125
+ def reset(self, rng_key: jax.Array) -> tuple[ReturnDiggingState, TimeStep]:
126
+ map_key, pos_key = jax.random.split(rng_key)
127
+
128
+ map, spawn_pos, spawn_count = self._generate_map(map_key)
129
+
130
+ unpadded_map = map[
131
+ self.pad_width : -self.pad_width, self.pad_height : -self.pad_height
132
+ ]
133
+
134
+ pos_x, pos_y = choose_positions(
135
+ unpadded_map,
136
+ self.num_flags + self.num_agents,
137
+ pos_key,
138
+ replace=False,
139
+ )
140
+
141
+ pos_x = pos_x + self.pad_width
142
+ pos_y = pos_y + self.pad_height
143
+ positions = jnp.stack((pos_x, pos_y), axis=1)
144
+ flag_pos = positions[: self.num_flags]
145
+ agents_pos = positions[self.num_flags :]
146
+
147
+ if self._config.eval_map:
148
+ o = map
149
+ map = map.at[36:45, 10:30].set(GW.TILE_DESTRUCTIBLE_WALL)
150
+ map = map.at[42:45, 17:23].set(o[42:45, 17:23])
151
+ agents_pos = agents_pos.at[0].set([44, 22])
152
+ map = map.at[43, 18].set(GW.TILE_FLAG)
153
+ else:
154
+ map = map.at[flag_pos[:, 0], flag_pos[:, 1]].set(GW.TILE_FLAG)
155
+
156
+ state = ReturnDiggingState(
157
+ map=map,
158
+ spawn_pos=spawn_pos,
159
+ spawn_count=spawn_count,
160
+ agents_pos=agents_pos,
161
+ agents_timeout=jnp.zeros((self.num_agents,), dtype=jnp.int32),
162
+ found_reward=jnp.zeros((self.num_agents,), dtype=jnp.bool_),
163
+ time=jnp.int32(0),
164
+ rewards=jnp.float32(0.0),
165
+ )
166
+
167
+ actions = jnp.zeros((self.num_agents,), dtype=jnp.int32)
168
+ rewards = jnp.zeros((self.num_agents,), dtype=jnp.float32)
169
+
170
+ return state, self.encode_observations(state, actions, rewards)
171
+
172
+ def load_map(self, map: str):
173
+ tiles = np.zeros((self.unpadded_width, self.unpadded_height), dtype=np.int8)
174
+
175
+ x = 0
176
+ y = self.unpadded_height
177
+
178
+ agent_positions = []
179
+ spawn_positions = []
180
+
181
+ for c in map:
182
+ if c == "\n":
183
+ x = 0
184
+ y -= 1
185
+ else:
186
+ match x:
187
+ case "x":
188
+ tiles[x, y] = GW.TILE_DESTRUCTIBLE_WALL
189
+ case "a":
190
+ agent_positions.append(
191
+ [self.pad_width + x, self.pad_height + y]
192
+ )
193
+ case "f":
194
+ tiles[x, y] = GW.TILE_FLAG
195
+ case _:
196
+ spawn_positions.append(
197
+ [self.pad_width + x, self.pad_height + y]
198
+ )
199
+
200
+ x += 1
201
+
202
+ tiles = jnp.asarray(tiles)
203
+ # pad the tiles
204
+ tiles = jnp.pad(
205
+ tiles,
206
+ pad_width=(
207
+ (self.pad_width, self.pad_width),
208
+ (self.pad_height, self.pad_height),
209
+ ),
210
+ mode="constant",
211
+ constant_values=GW.TILE_WALL,
212
+ )
213
+
214
+ state = ReturnDiggingState(
215
+ map=tiles,
216
+ spawn_pos=jnp.array(spawn_positions, jnp.int32),
217
+ spawn_count=jnp.int32(len(spawn_positions)),
218
+ agents_pos=jnp.arange(agent_positions, jnp.int32),
219
+ agents_timeout=jnp.zeros((self.num_agents,), dtype=jnp.int32),
220
+ found_reward=jnp.zeros((self.num_agents,), dtype=jnp.bool_),
221
+ time=jnp.int32(0),
222
+ rewards=jnp.float32(0.0),
223
+ )
224
+
225
+ actions = jnp.zeros((self.num_agents,), dtype=jnp.int32)
226
+ rewards = jnp.zeros((self.num_agents,), dtype=jnp.float32)
227
+
228
+ return state, self.encode_observations(state, actions, rewards)
229
+
230
+ @cached_property
231
+ def observation_spec(self) -> ObservationSpec:
232
+ return GW.make_obs_spec(self.view_width, self.view_height)
233
+
234
+ @cached_property
235
+ def action_spec(self) -> DiscreteActionSpec:
236
+ return DiscreteActionSpec(num_actions=GW.NUM_ACTIONS)
237
+
238
+ @property
239
+ def is_jittable(self) -> bool:
240
+ return True
241
+
242
+ @property
243
+ def num_agents(self) -> int:
244
+ return self._num_agents
245
+
246
+ def step(
247
+ self, state: ReturnDiggingState, action: jax.Array, rng_key: jax.Array
248
+ ) -> tuple[ReturnDiggingState, TimeStep]:
249
+ @partial(jax.vmap, in_axes=(0, 0, 0, 0), out_axes=(0, 0, 0, 0))
250
+ def _step_agent(local_position, timeout, local_action, random_position):
251
+ def _step_timeout(local_position, timeout, local_action, random_position):
252
+ return local_position, local_position, timeout - 1, 0.0
253
+
254
+ def _step_move(local_position, timeout, local_action, random_position):
255
+ target_pos = local_position + GW.DIRECTIONS[local_action]
256
+
257
+ new_tile = state.map[target_pos[0], target_pos[1]]
258
+
259
+ # don't move if we are moving into a wall
260
+ new_pos = jnp.where(
261
+ jnp.logical_or(
262
+ new_tile == GW.TILE_WALL, new_tile == GW.TILE_DESTRUCTIBLE_WALL
263
+ ),
264
+ local_position,
265
+ target_pos,
266
+ )
267
+
268
+ found_treasure = new_tile == GW.TILE_FLAG
269
+ reward = jnp.where(found_treasure, self.treasure_reward, 0.0)
270
+
271
+ # randomize position if the agent finds the reward
272
+ new_pos = jnp.where(found_treasure, random_position, new_pos)
273
+
274
+ # sets a timeout of the tile is dug
275
+ timeout = jnp.where(
276
+ new_tile == GW.TILE_DESTRUCTIBLE_WALL, self.digging_timeout, 0
277
+ )
278
+
279
+ return new_pos, target_pos, timeout, reward
280
+
281
+ return jax.lax.cond(
282
+ timeout > 0,
283
+ _step_timeout,
284
+ _step_move,
285
+ local_position,
286
+ timeout,
287
+ local_action,
288
+ random_position,
289
+ )
290
+
291
+ random_positions = state.spawn_pos[
292
+ jax.random.randint(
293
+ rng_key, (self._num_agents,), minval=0, maxval=state.spawn_count
294
+ )
295
+ ]
296
+ new_position, target_pos, timeout, rewards = _step_agent(
297
+ state.agents_pos, state.agents_timeout, action, random_positions
298
+ )
299
+
300
+ # dig actions
301
+ target_tiles = state.map[target_pos[:, 0], target_pos[:, 1]]
302
+ map = state.map.at[target_pos[:, 0], target_pos[:, 1]].set(
303
+ jnp.where(
304
+ target_tiles == GW.TILE_DESTRUCTIBLE_WALL, GW.TILE_EMPTY, target_tiles
305
+ )
306
+ )
307
+ # /dig actions
308
+
309
+ state = state._replace(
310
+ agents_pos=new_position,
311
+ agents_timeout=timeout,
312
+ found_reward=jnp.logical_or(state.found_reward, rewards),
313
+ time=state.time + 1,
314
+ rewards=state.rewards + jnp.mean(rewards),
315
+ map=map,
316
+ )
317
+
318
+ return state, self.encode_observations(state, action, rewards)
319
+
320
+ def _render_tiles(self, state: ReturnDiggingState):
321
+ tiles = state.map
322
+ tiles = tiles.at[state.agents_pos[:, 0], state.agents_pos[:, 1]].set(
323
+ GW.AGENT_GENERIC
324
+ )
325
+
326
+ directions = jnp.zeros_like(tiles, dtype=jnp.int8)
327
+ teams = jnp.zeros_like(tiles, dtype=jnp.int8)
328
+ health = jnp.zeros_like(tiles, dtype=jnp.int8)
329
+
330
+ return jnp.concatenate(
331
+ (
332
+ tiles[..., None],
333
+ directions[..., None],
334
+ teams[..., None],
335
+ health[..., None],
336
+ ),
337
+ axis=-1,
338
+ )
339
+
340
+ def encode_observations(
341
+ self, state: ReturnDiggingState, actions, rewards
342
+ ) -> TimeStep:
343
+ @partial(jax.vmap, in_axes=(None, 0))
344
+ def _encode_view(tiles, positions):
345
+ return jax.lax.dynamic_slice(
346
+ tiles,
347
+ (
348
+ positions[0] - self.view_width // 2,
349
+ positions[1] - self.view_height // 2,
350
+ 0,
351
+ ),
352
+ (self.view_width, self.view_height, self.observation_spec.shape[-1]),
353
+ )
354
+
355
+ tiles = self._render_tiles(state)
356
+ view = _encode_view(tiles, state.agents_pos)
357
+
358
+ time = jnp.repeat(state.time[None], self.num_agents, axis=0)
359
+
360
+ return TimeStep(
361
+ obs=view,
362
+ time=time,
363
+ last_action=actions,
364
+ last_reward=rewards,
365
+ action_mask=self._action_mask,
366
+ terminated=jnp.equal(time, self._length - 1),
367
+ )
368
+
369
+ def create_placeholder_logs(self):
370
+ return {"rewards": jnp.float32(0.0)}
371
+
372
+ def create_logs(self, state: ReturnDiggingState):
373
+ return {"rewards": state.rewards}
374
+
375
+ def get_render_state(self, state: ReturnDiggingState) -> GridRenderState:
376
+ tiles = self._render_tiles(state)
377
+
378
+ return GridRenderState(
379
+ tilemap=tiles,
380
+ agent_positions=state.agents_pos,
381
+ )
382
+
383
+ def get_render_settings(self) -> GridRenderSettings:
384
+ return GridRenderSettings(
385
+ tile_width=self.unpadded_width,
386
+ tile_height=self.unpadded_height,
387
+ view_width=self.view_width,
388
+ view_height=self.view_height,
389
+ )