dfa-gym 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dfa_gym/token_env.py ADDED
@@ -0,0 +1,571 @@
1
+ import jax
2
+ import chex
3
+ import numpy as np
4
+ import jax.numpy as jnp
5
+ from flax import struct
6
+ from enum import IntEnum
7
+ from functools import partial
8
+ from typing import Tuple, Dict
9
+ from dfa_gym import spaces
10
+ from dfa_gym.env import MultiAgentEnv, State
11
+
12
+
13
class Action(IntEnum):
    """Discrete moves available to each agent.

    The integer value of each member is the row of ``ACTION_MAP`` holding
    the matching (row, column) displacement.
    """

    DOWN = 0
    RIGHT = 1
    UP = 2
    LEFT = 3
    NOOP = 4
19
+
20
# (row, column) displacement applied for each action; rows are indexed by the
# action's integer value, so the order here must match the Action enum.
ACTION_MAP = jnp.array([
    [ 1,  0],  # DOWN
    [ 0,  1],  # RIGHT
    [-1,  0],  # UP
    [ 0, -1],  # LEFT
    [ 0,  0],  # NOOP
])
27
+
28
@struct.dataclass
class TokenEnvState(State):
    """Immutable snapshot of the token grid world carried between steps."""

    # (row, column) of every agent, one row per agent.
    agent_positions: jax.Array
    # Token cells indexed as [token id, repeat, (row, column)]; layouts parsed
    # from text pad missing repeats with -1.
    token_positions: jax.Array
    # (row, column) of every wall cell; empty when no layout is used.
    wall_positions: jax.Array
    # Per-wall flag, True while the wall is open (see compute_disabled_walls).
    is_wall_disabled: jax.Array
    # Button/door cells indexed as [button id, repeat, (row, column)] for
    # parsed layouts; an empty placeholder otherwise.
    button_positions: jax.Array
    # Per-agent flag, cleared when an agent is removed by a collision.
    is_alive: jax.Array
    # Number of environment steps taken so far.
    time: int
37
+
38
+ class TokenEnv(MultiAgentEnv):
39
+
40
def __init__(
    self,
    n_agents: int = 3,
    n_tokens: int = 10,
    n_token_repeat: int = 2,
    grid_shape: Tuple[int, int] = (7, 7),
    fixed_map_seed: int | None = None,
    max_steps_in_episode: int = 100,
    collision_reward: int | None = None,
    black_death: bool = True,
    layout: str | None = None
) -> None:
    """Configure the environment and derive its observation/action spaces.

    Args:
        n_agents: Number of agents (overridden by ``layout`` when given).
        n_tokens: Number of distinct token ids (overridden by ``layout``).
        n_token_repeat: Copies of each token on the grid (overridden by
            ``layout``).
        grid_shape: (rows, columns) of the grid (overridden by ``layout``).
        fixed_map_seed: When set, random maps are always sampled from this
            seed instead of the key passed to ``reset``.
        max_steps_in_episode: Step count at which every agent is done.
        collision_reward: Reward given to colliding agents, which are then
            removed. When ``None`` colliding agents are simply moved back.
        black_death: When true, dead agents receive an all-zero observation.
        layout: Optional text layout (see ``parse``) used as the fixed
            initial state.
    """
    super().__init__(num_agents=n_agents)
    assert (grid_shape[0] * grid_shape[1]) >= (n_agents + n_tokens * n_token_repeat)

    self.n_agents = n_agents
    self.n_tokens = n_tokens
    self.n_token_repeat = n_token_repeat
    self.grid_shape = grid_shape
    self.grid_shape_arr = jnp.array(self.grid_shape)
    self.fixed_map_seed = fixed_map_seed
    self.max_steps_in_episode = max_steps_in_episode
    self.collision_reward = collision_reward
    self.black_death = black_death
    # Buttons only exist in parsed layouts; keep both counters defined so the
    # attributes are available on randomly generated maps as well.
    self.n_buttons = 0
    self.n_button_repeat = 0

    self.agents = [f"agent_{i}" for i in range(self.n_agents)]

    self.init_state = None
    if layout is not None:
        # parse() overwrites grid_shape, n_agents, n_tokens, n_token_repeat,
        # n_buttons and agents with what the layout actually contains.
        self.init_state = self.parse(layout)
        self.num_agents = self.n_agents

    # Channel layout: own position, [wall, blocking wall], one per token id,
    # one per other agent, then three per button id.
    channel_dim = 1
    if self.init_state is not None:
        channel_dim += 2
    if self.n_tokens > 0:
        channel_dim += self.n_tokens
    if self.n_agents > 1:
        channel_dim += self.n_agents - 1
    if self.n_buttons > 0:
        channel_dim += 3 * self.n_buttons
    self.obs_shape = (channel_dim, *self.grid_shape)

    n_actions = len(ACTION_MAP)
    self.action_spaces = {
        agent: spaces.Discrete(n_actions) for agent in self.agents
    }
    self.observation_spaces = {
        agent: spaces.Box(low=0, high=1, shape=self.obs_shape, dtype=jnp.uint8)
        for agent in self.agents
    }
87
+
88
@partial(jax.jit, static_argnums=(0,))
def reset(
    self,
    key: chex.PRNGKey
) -> Tuple[Dict[str, chex.Array], TokenEnvState]:
    """Return the initial observations and state.

    A parsed layout, when present, is used as-is; otherwise a random map is
    sampled from ``key``.
    """
    if self.init_state is not None:
        state = self.init_state
    else:
        state = self.sample_init_state(key)
    return self.get_obs(state=state), state
97
+
98
@partial(jax.jit, static_argnums=(0,))
def step_env(
    self,
    key: chex.PRNGKey,
    state: TokenEnvState,
    actions: Dict[str, chex.Array]
) -> Tuple[Dict[str, chex.Array], TokenEnvState, Dict[str, float], Dict[str, bool], Dict]:
    """Advance the environment by one step.

    Agents move with wrap-around at the grid edges. Moves into a blocking
    wall, moves that would put several agents on one cell, and pairwise
    position swaps are undone. When ``collision_reward`` is set, colliding
    agents are instead rewarded with it and marked dead.

    Args:
        key: PRNG key; not used by this method.
        state: Current environment state.
        actions: Action index per agent name.

    Returns:
        ``(obs, new_state, rewards, dones, info)`` where ``dones`` also
        carries an ``"__all__"`` entry and ``info`` is empty.
    """
    _actions = jnp.array([actions[agent] for agent in self.agents])

    # Move agents (toroidal grid); dead agents stay where they are.
    new_agent_pos = (state.agent_positions + ACTION_MAP[_actions]) % self.grid_shape_arr
    new_agent_pos = jnp.where(state.is_alive[:, None], new_agent_pos, state.agent_positions)

    if self.init_state is not None:
        # Handle wall collisions: a move is undone when the target cell is a
        # wall that is not currently disabled.
        on_wall = jnp.all(
            new_agent_pos[:, None, :] == state.wall_positions[None, :, :], axis=-1
        )  # [n_agents, n_walls]
        wall_is_blocking = jnp.logical_not(state.is_wall_disabled)  # [n_walls]
        wall_collisions = jnp.any(
            jnp.logical_and(on_wall, wall_is_blocking[None, :]), axis=-1
        )  # [n_agents]
        new_agent_pos = jnp.where(wall_collisions[:, None], state.agent_positions, new_agent_pos)

    # Handle collisions
    # TODO: When collision_reward is not None, there might be unintended behavior.
    # +-------+-------+-------+-------+-------+-------+-------+
    # | 0     | #     | .     | #     | .     | #     | 2     |
    # +-------+-------+-------+-------+-------+-------+-------+
    # | .     | #     | .     | #     | 1     | #     | .     |
    # +-------+-------+-------+-------+-------+-------+-------+
    # | 4     | #     | 9     | 5     | 3     | #     | 7     |
    # +-------+-------+-------+-------+-------+-------+-------+
    # | .     | .     | 8     | #     | .     | #     | 7     |
    # +-------+-------+-------+-------+-------+-------+-------+
    # | 6     | #     | 2     | #     | 9     | #     | 3     |
    # +-------+-------+-------+-------+-------+-------+-------+
    # | 8     | #     | 0     | #     | .     | #     | A_1,5 |
    # +-------+-------+-------+-------+-------+-------+-------+
    # | .     | #     | 1     | #     | 4     | A_2   | A_0,6 |
    # +-------+-------+-------+-------+-------+-------+-------+
    # Action for agent_0
    # 3
    # Action for agent_1
    # 0
    # Action for agent_2
    # 1
    # Gives
    # {'agent_0': Array(-100., dtype=float32), 'agent_1': Array(-100., dtype=float32), 'agent_2': Array(-100., dtype=float32)}
    # {'__all__': Array(True, dtype=bool), 'agent_0': Array(True, dtype=bool), 'agent_1': Array(True, dtype=bool), 'agent_2': Array(True, dtype=bool)}
    def compute_collisions(mask):
        # Agents flagged in `mask` are evaluated at their previous position.
        positions = jnp.where(mask[:, None], state.agent_positions, new_agent_pos)
        rows, cols = positions[:, 0], positions[:, 1]
        # Scatter-add accumulates repeated indices, giving per-cell occupancy.
        occupancy = jnp.zeros(self.grid_shape).at[rows, cols].add(1)
        return jnp.logical_and(state.is_alive, occupancy[rows, cols] > 1)

    # Keep sending colliding agents back until no alive agent shares a cell;
    # sending one agent back can create a new conflict with another.
    collisions = jax.lax.while_loop(
        lambda mask: jnp.any(compute_collisions(mask)),
        lambda mask: jnp.logical_or(mask, compute_collisions(mask)),
        jnp.zeros((self.n_agents,), dtype=bool)
    )

    if self.collision_reward is None:
        # Collisions are not penalised: undo the moves and forget them.
        new_agent_pos = jnp.where(collisions[:, None], state.agent_positions, new_agent_pos)
        collisions = jnp.full(collisions.shape, False)

    # Handle swaps: entry [i, j] is True when agent i moved onto the cell
    # agent j came from; a swap is a symmetric pair of such entries.
    moved_onto = (
        state.agent_positions[None, :, :] == new_agent_pos[:, None, :]
    ).all(axis=-1)
    moved_onto = jnp.fill_diagonal(moved_onto, False, inplace=False)
    swaps = jnp.any(jnp.logical_and(moved_onto, moved_onto.T), axis=0)
    new_agent_pos = jnp.where(swaps[:, None], state.agent_positions, new_agent_pos)

    _rewards = jnp.zeros((self.n_agents,), dtype=jnp.float32)
    if self.collision_reward is not None:
        _rewards = jnp.where(
            jnp.logical_and(state.is_alive, collisions), self.collision_reward, _rewards
        )
    rewards = {agent: _rewards[i] for i, agent in enumerate(self.agents)}

    if self.init_state is not None:
        is_wall_disabled = self.compute_disabled_walls(
            new_agent_pos, state.wall_positions, state.button_positions
        )
    else:
        # No walls on random maps: carry the (empty) flags through unchanged
        # so the state keeps the same shape from step to step.
        is_wall_disabled = state.is_wall_disabled

    new_state = TokenEnvState(
        agent_positions=new_agent_pos,
        token_positions=state.token_positions,
        wall_positions=state.wall_positions,
        is_wall_disabled=is_wall_disabled,
        button_positions=state.button_positions,
        is_alive=jnp.logical_and(state.is_alive, jnp.logical_not(collisions)),
        time=state.time + 1
    )

    # NOTE(review): an agent killed on an earlier step reports done=False
    # here until the time limit is reached — confirm this is intended.
    _dones = jnp.logical_or(collisions, new_state.time >= self.max_steps_in_episode)
    dones = {a: _dones[i] for i, a in enumerate(self.agents)}
    dones.update({"__all__": jnp.all(_dones)})

    obs = self.get_obs(state=new_state)
    info = {}

    return obs, new_state, rewards, dones, info
222
+
223
@partial(jax.jit, static_argnums=(0,))
def get_obs(
    self,
    state: "TokenEnvState"
) -> Dict[str, "chex.Array"]:
    """Build the egocentric binary observation of every agent.

    Each observation has shape ``self.obs_shape`` and is shifted (with
    wrap-around) so that the observing agent sits at cell (0, 0). Channels,
    in order: own position; wall and blocking-wall (layouts only); one per
    token id; one per other agent; and per button id a
    (button-door pair, button, door) triple. Padding coordinates (negative
    entries used to fill uneven token/button repeats) are not drawn. With
    ``black_death`` enabled, dead agents get an all-zero observation.
    """

    def obs_for_agent(i):
        base = jnp.zeros(self.obs_shape, dtype=jnp.uint8)
        # ref = self.grid_shape_arr // 2
        ref = jnp.array([0, 0])
        offset = ref - state.agent_positions[i]

        def to_local(pos):
            # Shift absolute coordinates into the agent-centred frame.
            return (pos + offset) % self.grid_shape_arr

        def is_real(pos):
            # Negative coordinates are padding, not grid cells.
            return jnp.all(pos >= 0, axis=-1)

        own = to_local(state.agent_positions[i])
        b = base.at[0, own[0], own[1]].set(1)  # Is agent?
        idx_offset = 1

        if self.init_state is not None:
            wall_rel = to_local(state.wall_positions)
            wall_rows, wall_cols = wall_rel[:, 0], wall_rel[:, 1]
            blocking = jnp.logical_not(state.is_wall_disabled).astype(jnp.uint8)
            b = b.at[idx_offset, wall_rows, wall_cols].set(1)  # Is wall?
            b = b.at[idx_offset + 1, wall_rows, wall_cols].set(blocking)  # Is wall blocking?
            idx_offset += 2

        if self.n_tokens > 0:
            token_base = idx_offset

            def place_token(token_idx, val):
                pos = state.token_positions[token_idx]  # [n_token_repeat, 2]
                rel = to_local(pos)
                # max() leaves the cell untouched for padding entries, which
                # would otherwise wrap onto a real cell through the modulo.
                return val.at[
                    token_base + token_idx, rel[:, 0], rel[:, 1]
                ].max(is_real(pos).astype(jnp.uint8))  # Is token?

            b = jax.lax.fori_loop(0, self.n_tokens, place_token, b)
            idx_offset += self.n_tokens

        if self.n_agents > 1:
            other_base = idx_offset

            def place_other(other_idx, val):
                # Skip over the observing agent's own index.
                rel = to_local(state.agent_positions[other_idx + (other_idx >= i)])
                return val.at[other_base + other_idx, rel[0], rel[1]].set(1)  # Is other agent?

            b = jax.lax.fori_loop(0, self.n_agents - 1, place_other, b)
            idx_offset += self.n_agents - 1

        if self.n_buttons > 0:
            button_base = idx_offset

            def place_button(button_idx, val):
                pos = state.button_positions[button_idx]  # [n_button_repeat, 2]
                real = is_real(pos)
                # Buttons are considered doors if they are in a wall.
                is_door = jnp.any(
                    jnp.all(pos[:, None, :] == state.wall_positions[None, :, :], axis=-1),
                    axis=-1,
                )
                is_button = jnp.logical_and(real, jnp.logical_not(is_door))
                rel = to_local(pos)
                rows, cols = rel[:, 0], rel[:, 1]
                channel = button_base + 3 * button_idx
                val = val.at[channel, rows, cols].max(real.astype(jnp.uint8))  # Is button-door pair?
                val = val.at[channel + 1, rows, cols].max(is_button.astype(jnp.uint8))  # Is button?
                return val.at[channel + 2, rows, cols].max(is_door.astype(jnp.uint8))  # Is door?

            b = jax.lax.fori_loop(0, self.n_buttons, place_button, b)

        visible = jnp.logical_or(jnp.logical_not(self.black_death), state.is_alive[i])
        return jnp.where(visible, b, base)

    obs = jax.vmap(obs_for_agent)(jnp.arange(self.n_agents))
    return {agent: obs[i] for i, agent in enumerate(self.agents)}
290
+
291
@partial(jax.jit, static_argnums=(0,))
def label_f(self, state: "TokenEnvState") -> Dict[str, int]:
    """Return, per agent, the id of the token it stands on, or -1.

    Dead agents and agents that are not on any token cell are labelled -1.
    If an agent's cell holds several token ids, the lowest id is reported.
    """
    # [n_agents, n_tokens, n_token_repeat, 2]
    diffs = state.agent_positions[:, None, None, :] - state.token_positions[None, :, :, :]
    on_copy = jnp.all(diffs == 0, axis=-1)   # [n_agents, n_tokens, n_token_repeat]
    on_token = jnp.any(on_copy, axis=-1)     # [n_agents, n_tokens]

    has_match = jnp.any(on_token, axis=1)
    token_idx = jnp.argmax(on_token, axis=1)  # first matching token id

    agent_token_matches = jnp.where(
        jnp.logical_and(has_match, state.is_alive), token_idx, -1
    )

    return {agent: agent_token_matches[i] for i, agent in enumerate(self.agents)}
304
+
305
@partial(jax.jit, static_argnums=(0,))
def sample_init_state(
    self,
    key: chex.PRNGKey
) -> TokenEnvState:
    """Sample a wall-free initial state with agents and tokens on distinct cells.

    All grid cells are shuffled; the first ``n_agents`` become agent
    positions and the next ``n_tokens * n_token_repeat`` become token
    positions. When ``fixed_map_seed`` is set it replaces ``key``, so the
    same map is produced on every call.
    """
    if self.fixed_map_seed is not None:
        key = jax.random.PRNGKey(self.fixed_map_seed)

    n_rows, n_cols = self.grid_shape[0], self.grid_shape[1]
    grid_points = jnp.stack(jnp.meshgrid(jnp.arange(n_rows), jnp.arange(n_cols)), -1)
    grid_flat = grid_points.reshape(-1, 2)  # every (row, column) pair once

    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, grid_flat.shape[0])
    shuffled = grid_flat[perm]

    first_token = self.n_agents
    last_token = first_token + self.n_tokens * self.n_token_repeat
    agent_positions = shuffled[:first_token]
    token_positions = shuffled[first_token:last_token].reshape(
        self.n_tokens, self.n_token_repeat, 2
    )

    return TokenEnvState(
        agent_positions=agent_positions,
        token_positions=token_positions,
        # Random maps have no walls or buttons; these are empty placeholders.
        wall_positions=jnp.empty((0, 2), dtype=jnp.int32),
        is_wall_disabled=jnp.empty((0, 2), dtype=bool),
        button_positions=jnp.empty((0, 2), dtype=jnp.int32),
        is_alive=jnp.ones((self.n_agents,), dtype=bool),
        time=0
    )
331
+
332
@partial(jax.jit, static_argnums=(0,))
def compute_disabled_walls(self, agent_positions, wall_positions, button_positions):
    """Return a boolean flag per wall that is True while the wall is open.

    ``button_positions`` is indexed as [button id, repeat, (row, column)].
    A button id counts as pressed when any agent stands on any of its cells,
    and a wall is disabled when its cell belongs to a pressed button id.

    Args:
        agent_positions: Agent coordinates, one (row, column) per agent.
        wall_positions: Wall coordinates, one (row, column) per wall.
        button_positions: Button/door coordinates grouped by button id.

    Returns:
        Boolean array with one entry per wall.
    """
    # [n_buttons, n_button_repeat, n_agents]: is this agent on this cell?
    agent_on_cell = jnp.all(
        button_positions[:, :, None, :] == agent_positions[None, None, :, :], axis=-1
    )
    on_buttons = jnp.any(agent_on_cell, axis=(1, 2))  # [n_buttons]

    # [n_walls, n_buttons, n_button_repeat]: does this wall sit on this cell?
    wall_on_cell = jnp.all(
        wall_positions[:, None, None, :] == button_positions[None, :, :, :], axis=-1
    )
    wall_in_group = jnp.any(wall_on_cell, axis=-1)  # [n_walls, n_buttons]

    # A wall is disabled if *any* matching button is pressed.
    return jnp.any(jnp.logical_and(wall_in_group, on_buttons[None, :]), axis=-1)
349
+
350
def parse(self, layout: str) -> TokenEnvState:
    """Build the initial state from a text layout and adopt its dimensions.

    Besides returning the state, this overwrites ``grid_shape``,
    ``grid_shape_arr``, ``n_agents``, ``agents``, ``n_tokens``,
    ``n_token_repeat``, ``n_buttons`` and ``n_button_repeat`` on ``self``.
    Token and button ids with fewer cells than the maximum repeat count are
    padded with ``-1`` coordinates.

    Args:
        layout: One grid row per line, every cell written as ``[...]``.
            Entries inside a cell are comma separated; surrounding
            whitespace is ignored.

    Returns:
        The initial ``TokenEnvState`` described by the layout.

    Raises:
        ValueError: If the layout is malformed, empty or ragged, a cell
            repeats an entry kind, an agent letter is placed twice, or the
            agent letters do not form a contiguous run starting at ``A``.
        AssertionError: If a cell combines a wall with an agent or token,
            or a token with a button.
    """
    # Example layout:
    # [ 8 ][   ][   ][   ][   ][   ][   ][ # ][ 0 ][   ][   ][ 1 ]
    # [   ][   ][   ][   ][   ][   ][   ][ # ][   ][   ][   ][   ]
    # [   ][   ][ b ][   ][   ][   ][   ][ # ][   ][   ][   ][   ]
    # [   ][   ][   ][   ][   ][   ][   ][ # ][ 3 ][   ][   ][ 2 ]
    # [   ][   ][   ][   ][   ][   ][   ][ # ][ # ][ # ][#,a][ # ]
    # [ A ][   ][   ][   ][   ][   ][   ][   ][   ][   ][   ][   ]
    # [ B ][   ][   ][   ][   ][   ][   ][   ][   ][   ][   ][   ]
    # [   ][   ][   ][   ][   ][   ][   ][ # ][ # ][ # ][#,b][ # ]
    # [   ][   ][   ][   ][   ][   ][   ][ # ][ 4 ][   ][   ][ 5 ]
    # [   ][   ][ a ][   ][   ][   ][   ][ # ][   ][   ][   ][   ]
    # [   ][   ][   ][   ][   ][   ][   ][ # ][   ][   ][   ][   ]
    # [ 9 ][   ][   ][   ][   ][   ][   ][ # ][ 7 ][   ][   ][ 6 ]
    # where each [] indicates a cell, uppercase letters indicate agents
    # (e.g., A and B), digits indicate tokens, and cells with # indicate
    # walls. Lowercase letters indicate "sync" points, which act as buttons
    # and the doors they open: e.g., [#,a] is a door that is open while an
    # agent is on the cell [ a ] and closed otherwise.

    # --- parse into a 2D list of cell contents (strings inside brackets) ---
    rows: list[list[str]] = []
    for raw_line in layout.strip().splitlines():
        line = raw_line.strip()
        if not line:
            continue
        cells = []
        cursor = 0
        while cursor < len(line):
            if line[cursor] != "[":
                cursor += 1  # anything outside brackets is ignored
                continue
            close = line.find("]", cursor + 1)
            if close == -1:
                raise ValueError("Malformed layout: missing closing ']' in a row.")
            cells.append(line[cursor + 1:close].strip())
            cursor = close + 1
        if cells:
            rows.append(cells)

    if not rows:
        raise ValueError("Parsed layout is empty.")

    H = len(rows)
    W = len(rows[0])
    if any(len(r) != W for r in rows):
        raise ValueError("All rows in the layout must have the same number of cells.")

    # --- collect info ---
    wall_positions: list[tuple[int, int]] = []
    agent_positions: dict[int, tuple[int, int]] = {}
    token_positions: dict[int, list[tuple[int, int]]] = {}
    button_positions: dict[int, list[tuple[int, int]]] = {}
    max_token_id = -1
    max_button_id = -1

    def _is_wall(c: str) -> bool:
        return c == "#"

    def _is_agent(c: str) -> bool:
        return len(c) == 1 and "A" <= c <= "Z"

    def _get_agent_idx(c: str) -> int:
        return ord(c) - ord("A")

    def _is_token(c: str) -> bool:
        try:
            int(c)
        except ValueError:
            return False
        return True

    def _is_button(c: str) -> bool:
        return len(c) == 1 and "a" <= c <= "z"

    def _get_button_idx(c: str) -> int:
        return ord(c) - ord("a")

    for r, row in enumerate(rows):
        for c, cell in enumerate(row):
            has_wall = False
            has_agent = False
            has_token = False
            has_button = False
            # Strip each entry so "[#, a]" and "[#,a]" are read the same way.
            for content in (part.strip() for part in cell.split(",")):
                if _is_wall(content):
                    if has_wall:
                        raise ValueError("One wall per cell.")
                    has_wall = True
                    wall_positions.append((r, c))
                elif _is_agent(content):
                    if has_agent:
                        raise ValueError("One agent per cell.")
                    has_agent = True
                    idx = _get_agent_idx(content)
                    if idx in agent_positions:
                        raise ValueError(f"Duplicate placement for agent '{content}'.")
                    agent_positions[idx] = (r, c)
                elif _is_token(content):
                    if has_token:
                        raise ValueError("One token per cell.")
                    has_token = True
                    tok_id = int(content)
                    max_token_id = max(max_token_id, tok_id)
                    token_positions.setdefault(tok_id, []).append((r, c))
                elif _is_button(content):
                    if has_button:
                        raise ValueError("One button per cell.")
                    has_button = True
                    idx = _get_button_idx(content)
                    max_button_id = max(max_button_id, idx)
                    button_positions.setdefault(idx, []).append((r, c))

            assert not (has_wall and has_agent)
            assert not (has_wall and has_token)
            assert not (has_token and has_button)

    # Agent letters index rows of the position array, so gaps (e.g. A and C
    # without B) cannot be represented.
    missing_agents = [
        chr(ord("A") + i) for i in range(len(agent_positions)) if i not in agent_positions
    ]
    if missing_agents:
        raise ValueError(
            "Agent letters must be contiguous starting at 'A'; missing: "
            + ", ".join(missing_agents)
        )

    # --- override environment settings ---
    self.grid_shape = (H, W)
    self.grid_shape_arr = jnp.array(self.grid_shape)

    self.n_agents = len(agent_positions)
    self.agents = [f"agent_{i}" for i in range(self.n_agents)]

    self.n_tokens = max_token_id + 1 if max_token_id >= 0 else 0
    self.n_token_repeat = max((len(v) for v in token_positions.values()), default=0)
    token_positions_np = np.full((self.n_tokens, self.n_token_repeat, 2), -1, dtype=np.int32)
    for tid in range(self.n_tokens):
        for k, cell_pos in enumerate(token_positions.get(tid, [])):
            token_positions_np[tid, k] = cell_pos

    agent_positions_np = np.full((self.n_agents, 2), -1, dtype=np.int32)
    for idx, cell_pos in agent_positions.items():
        agent_positions_np[idx] = cell_pos

    if wall_positions:
        wall_positions_np = np.array(wall_positions, dtype=np.int32)
    else:
        wall_positions_np = np.empty((0, 2), dtype=np.int32)

    self.n_buttons = max_button_id + 1 if max_button_id >= 0 else 0
    self.n_button_repeat = max((len(v) for v in button_positions.values()), default=0)
    button_positions_np = np.full((self.n_buttons, self.n_button_repeat, 2), -1, dtype=np.int32)
    for bid in range(self.n_buttons):
        for k, cell_pos in enumerate(button_positions.get(bid, [])):
            button_positions_np[bid, k] = cell_pos

    agent_positions_jnp = jnp.array(agent_positions_np)
    wall_positions_jnp = jnp.array(wall_positions_np)
    button_positions_jnp = jnp.array(button_positions_np)

    # --- return state ---
    return TokenEnvState(
        agent_positions=agent_positions_jnp,
        token_positions=jnp.array(token_positions_np),
        wall_positions=wall_positions_jnp,
        is_wall_disabled=self.compute_disabled_walls(
            agent_positions_jnp, wall_positions_jnp, button_positions_jnp
        ),
        button_positions=button_positions_jnp,
        is_alive=jnp.ones((self.n_agents,), dtype=bool),
        time=0,
    )
530
+
531
def render(self, state: "TokenEnvState"):
    """Print an ASCII rendering of ``state`` to standard output.

    Cells show ``.`` when empty, ``#`` for walls that are not disabled, the
    token id for tokens, ``b_<id>`` for buttons/doors and ``A_<index>`` for
    agents; several entries in one cell are joined with commas. Negative
    (padding) token and button coordinates are skipped.
    """
    empty_cell = "."
    wall_cell = "#"
    grid = np.full(self.grid_shape, empty_cell, dtype=object)

    def _cell(pos):
        # Plain ints for indexing; None for padding, which must not be
        # drawn (negative indices would otherwise wrap to the far edge).
        row, col = int(pos[0]), int(pos[1])
        if row < 0 or col < 0:
            return None
        return row, col

    for disabled, pos in zip(state.is_wall_disabled, state.wall_positions):
        if not disabled:
            grid[int(pos[0]), int(pos[1])] = f"{wall_cell}"

    for token, positions in enumerate(state.token_positions):
        for pos in positions:
            cell = _cell(pos)
            if cell is not None:
                grid[cell] = f"{token}"

    for button, positions in enumerate(state.button_positions):
        for pos in positions:
            cell = _cell(pos)
            if cell is None:
                continue
            current = grid[cell]
            grid[cell] = f"b_{button}" if current == empty_cell else f"{current},b_{button}"

    for agent in range(self.n_agents):
        pos = state.agent_positions[agent]
        cell = (int(pos[0]), int(pos[1]))
        current = grid[cell]
        grid[cell] = f"A_{agent}" if current == empty_cell else f"A_{agent},{current}"

    max_width = max(len(str(cell)) for row in grid for cell in row)

    h_line = "+" + "+".join(["-" * (max_width + 2) for _ in range(self.grid_shape[1])]) + "+"
    lines = [h_line]
    for row in grid:
        lines.append("| " + " | ".join(f"{str(cell):<{max_width}}" for cell in row) + " |")
        lines.append(h_line)

    print("\n".join(lines) + "\n")
571
+