dfa-gym 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dfa_gym/__init__.py +5 -15
- dfa_gym/dfa_bisim_env.py +121 -0
- dfa_gym/dfa_wrapper.py +185 -52
- dfa_gym/env.py +168 -0
- dfa_gym/maps/2buttons_2agents.pdf +0 -0
- dfa_gym/maps/2rooms_2agents.pdf +0 -0
- dfa_gym/maps/4buttons_4agents.pdf +0 -0
- dfa_gym/maps/4rooms_4agents.pdf +0 -0
- dfa_gym/robot.png +0 -0
- dfa_gym/spaces.py +156 -0
- dfa_gym/token_env.py +571 -0
- dfa_gym/utils.py +266 -0
- dfa_gym-0.2.0.dist-info/METADATA +93 -0
- dfa_gym-0.2.0.dist-info/RECORD +16 -0
- {dfa_gym-0.1.0.dist-info → dfa_gym-0.2.0.dist-info}/WHEEL +1 -1
- dfa_gym/dfa_env.py +0 -45
- dfa_gym-0.1.0.dist-info/METADATA +0 -11
- dfa_gym-0.1.0.dist-info/RECORD +0 -7
- {dfa_gym-0.1.0.dist-info → dfa_gym-0.2.0.dist-info}/licenses/LICENSE +0 -0
dfa_gym/token_env.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import chex
|
|
3
|
+
import numpy as np
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from flax import struct
|
|
6
|
+
from enum import IntEnum
|
|
7
|
+
from functools import partial
|
|
8
|
+
from typing import Tuple, Dict
|
|
9
|
+
from dfa_gym import spaces
|
|
10
|
+
from dfa_gym.env import MultiAgentEnv, State
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Action(IntEnum):
    """Discrete moves an agent can take on the grid.

    Each value is the row index of the matching displacement in
    ``ACTION_MAP``, so the numeric values must not be reordered.
    """

    DOWN = 0   # row + 1
    RIGHT = 1  # col + 1
    UP = 2     # row - 1
    LEFT = 3   # col - 1
    NOOP = 4   # stay in place
|
|
19
|
+
|
|
20
|
+
# (row, col) displacement applied for each action; row k belongs to the
# action whose integer value is k.
ACTION_MAP = jnp.array(
    (
        (1, 0),    # DOWN
        (0, 1),    # RIGHT
        (-1, 0),   # UP
        (0, -1),   # LEFT
        (0, 0),    # NOOP
    )
)
|
|
27
|
+
|
|
28
|
+
@struct.dataclass
class TokenEnvState(State):
    """Complete state of a ``TokenEnv`` episode.

    Positions are (row, col) integer pairs. States built from a layout may
    pad the token/button tables with ``-1`` rows when ids have an uneven
    number of placements.
    """

    # (n_agents, 2): current cell of every agent.
    agent_positions: jax.Array
    # (n_tokens, n_token_repeat, 2): cells holding each token id.
    token_positions: jax.Array
    # (n_walls, 2): wall cells; empty when no layout is used.
    wall_positions: jax.Array
    # One flag per wall: True while a matching button is pressed.
    is_wall_disabled: jax.Array
    # (n_buttons, n_button_repeat, 2): cells tied to each button id.
    button_positions: jax.Array
    # (n_agents,): False once an agent has been removed by a collision.
    is_alive: jax.Array
    # Number of environment steps taken so far in the episode.
    time: int
|
|
37
|
+
|
|
38
|
+
class TokenEnv(MultiAgentEnv):
    """Multi-agent grid world with tokens, walls and button-operated doors.

    Agents move on a wrap-around grid (positions are taken modulo the grid
    shape). Without a ``layout`` the map is sampled at random and contains
    only agents and tokens. With a ``layout`` the map is parsed from text
    and may also contain walls and buttons; the constructor arguments
    ``n_agents``, ``n_tokens``, ``n_token_repeat`` and ``grid_shape`` are
    then overridden by what the layout contains.
    """

    def __init__(
        self,
        n_agents: int = 3,
        n_tokens: int = 10,
        n_token_repeat: int = 2,
        grid_shape: Tuple[int, int] = (7, 7),
        fixed_map_seed: int | None = None,
        max_steps_in_episode: int = 100,
        collision_reward: int | None = None,
        black_death: bool = True,
        layout: str | None = None
    ) -> None:
        """Configure the environment.

        Args:
            n_agents: Number of agents (ignored when ``layout`` is given).
            n_tokens: Number of distinct token ids (ignored with ``layout``).
            n_token_repeat: Copies of each token id (ignored with ``layout``).
            grid_shape: (rows, cols) of the grid (ignored with ``layout``).
            fixed_map_seed: If set, random maps are always sampled with this
                seed instead of the key passed to ``reset``.
            max_steps_in_episode: Step count at which every agent is done.
            collision_reward: If ``None``, colliding agents are simply kept
                in place. Otherwise colliding agents receive this reward,
                are marked dead and are reported as done.
            black_death: If True, dead agents observe an all-zero tensor.
            layout: Optional textual map, see ``parse``.
        """
        super().__init__(num_agents=n_agents)
        assert (grid_shape[0] * grid_shape[1]) >= (n_agents + n_tokens * n_token_repeat)
        self.n_agents = n_agents
        self.n_tokens = n_tokens
        self.n_token_repeat = n_token_repeat
        self.grid_shape = grid_shape
        self.grid_shape_arr = jnp.array(self.grid_shape)
        self.fixed_map_seed = fixed_map_seed
        self.max_steps_in_episode = max_steps_in_episode
        self.collision_reward = collision_reward
        self.black_death = black_death
        self.n_buttons = 0
        # Defined up front so the attribute exists even without a layout.
        self.n_button_repeat = 0

        self.agents = [f"agent_{i}" for i in range(self.n_agents)]

        self.init_state = None
        if layout is not None:
            # parse() overrides grid_shape, n_agents, n_tokens, n_buttons, ...
            self.init_state = self.parse(layout)
            self.num_agents = self.n_agents

        # Observation channels: self, [wall, blocking wall], one per token,
        # one per other agent, three per button (pair, button, door).
        channel_dim = 1
        if self.init_state is not None:
            channel_dim += 2
        if self.n_tokens > 0:
            channel_dim += self.n_tokens
        if self.n_agents > 1:
            channel_dim += self.n_agents - 1
        if self.n_buttons > 0:
            channel_dim += 3 * self.n_buttons
        self.obs_shape = (channel_dim, *self.grid_shape)

        self.action_spaces = {
            agent: spaces.Discrete(len(ACTION_MAP))
            for agent in self.agents
        }
        self.observation_spaces = {
            agent: spaces.Box(low=0, high=1, shape=self.obs_shape, dtype=jnp.uint8)
            for agent in self.agents
        }

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self,
        key: chex.PRNGKey
    ) -> Tuple[Dict[str, chex.Array], TokenEnvState]:
        """Return the initial observations and state.

        The parsed layout state is used when a layout was given; otherwise
        a random map is sampled from ``key``.
        """
        state = self.init_state
        if state is None:
            state = self.sample_init_state(key)
        obs = self.get_obs(state=state)
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step_env(
        self,
        key: chex.PRNGKey,
        state: TokenEnvState,
        actions: Dict[str, chex.Array]
    ) -> Tuple[Dict[str, chex.Array], TokenEnvState, Dict[str, float], Dict[str, bool], Dict]:
        """Advance the environment by one step.

        Args:
            key: Unused by this method.
            state: Current state.
            actions: Action index per agent name.

        Returns:
            ``(obs, new_state, rewards, dones, info)``; ``dones`` also holds
            an ``"__all__"`` entry and ``info`` is an empty dict.
        """

        _actions = jnp.array([actions[agent] for agent in self.agents])

        # Move agents (the grid wraps around); dead agents stay where they are.
        def move_agent(pos, a):
            return (pos + ACTION_MAP[a]) % self.grid_shape_arr
        new_agent_pos = jax.vmap(move_agent, in_axes=(0, 0))(state.agent_positions, _actions)
        new_agent_pos = jnp.where(state.is_alive[:, None], new_agent_pos, state.agent_positions)

        if self.init_state is not None:
            # Handle wall collisions: moving into an enabled wall is undone.
            def compute_wall_collisions(pos, wall_positions, is_wall_disabled):
                return jnp.any(
                    jnp.logical_and(
                        jnp.logical_not(is_wall_disabled),
                        jnp.all(
                            pos[None, :] == wall_positions  # [N, 2]
                        , axis=-1),  # [N,]
                    )
                , axis=-1)  # [1,]
            wall_collisions = jax.vmap(compute_wall_collisions, in_axes=(0, None, None))(new_agent_pos, state.wall_positions, state.is_wall_disabled)
            new_agent_pos = jnp.where(wall_collisions[:, None], state.agent_positions, new_agent_pos)

        # Handle collisions
        # TODO: When collision_reward is not None, there might be unintended behavior.
        # +-------+-------+-------+-------+-------+-------+-------+
        # | 0     | #     | .     | #     | .     | #     | 2     |
        # +-------+-------+-------+-------+-------+-------+-------+
        # | .     | #     | .     | #     | 1     | #     | .     |
        # +-------+-------+-------+-------+-------+-------+-------+
        # | 4     | #     | 9     | 5     | 3     | #     | 7     |
        # +-------+-------+-------+-------+-------+-------+-------+
        # | .     | .     | 8     | #     | .     | #     | 7     |
        # +-------+-------+-------+-------+-------+-------+-------+
        # | 6     | #     | 2     | #     | 9     | #     | 3     |
        # +-------+-------+-------+-------+-------+-------+-------+
        # | 8     | #     | 0     | #     | .     | #     | A_1,5 |
        # +-------+-------+-------+-------+-------+-------+-------+
        # | .     | #     | 1     | #     | 4     | A_2   | A_0,6 |
        # +-------+-------+-------+-------+-------+-------+-------+
        # Action for agent_0
        # 3
        # Action for agent_1
        # 0
        # Action for agent_2
        # 1
        # Gives
        # {'agent_0': Array(-100., dtype=float32), 'agent_1': Array(-100., dtype=float32), 'agent_2': Array(-100., dtype=float32)}
        # {'__all__': Array(True, dtype=bool), 'agent_0': Array(True, dtype=bool), 'agent_1': Array(True, dtype=bool), 'agent_2': Array(True, dtype=bool)}
        def compute_collisions(mask):
            # Agents already flagged in `mask` are put back on their old cell.
            positions = jnp.where(mask[:, None], state.agent_positions, new_agent_pos)

            # Count how many agents occupy every cell.
            collision_grid = jnp.zeros(self.grid_shape)
            collision_grid, _ = jax.lax.scan(
                lambda grid, pos: (grid.at[pos[0], pos[1]].add(1), None),
                collision_grid,
                positions,
            )

            collision_mask = collision_grid > 1

            collisions = jax.vmap(lambda p: collision_mask[p[0], p[1]])(positions)
            return jnp.logical_and(state.is_alive, collisions)

        # Keep reverting colliding agents until no shared cell remains;
        # reverting one agent can create a new conflict on its old cell.
        collisions = jax.lax.while_loop(
            lambda mask: jnp.any(compute_collisions(mask)),
            lambda mask: jnp.logical_or(mask, compute_collisions(mask)),
            jnp.zeros((self.n_agents,), dtype=bool)
        )

        if self.collision_reward is None:
            # Collisions only block the move; nobody is penalised or removed.
            new_agent_pos = jnp.where(collisions[:, None], state.agent_positions, new_agent_pos)
            collisions = jnp.full(collisions.shape, False)

        # Handle swaps: two agents may not exchange cells in a single step.
        def compute_swaps(original_positions, new_positions):
            original_pos_expanded = jnp.expand_dims(original_positions, axis=0)
            new_pos_expanded = jnp.expand_dims(new_positions, axis=1)

            # swap_mask[i, j]: agent i now stands where agent j used to be.
            swap_mask = (original_pos_expanded == new_pos_expanded).all(axis=-1)
            swap_mask = jnp.fill_diagonal(swap_mask, False, inplace=False)

            swap_pairs = jnp.logical_and(swap_mask, swap_mask.T)

            swaps = jnp.any(swap_pairs, axis=0)
            return swaps

        swaps = compute_swaps(state.agent_positions, new_agent_pos)
        new_agent_pos = jnp.where(swaps[:, None], state.agent_positions, new_agent_pos)

        _rewards = jnp.zeros((self.n_agents,), dtype=jnp.float32)
        if self.collision_reward is not None:
            _rewards = jnp.where(jnp.logical_and(state.is_alive, collisions), self.collision_reward, _rewards)
        rewards = {agent: _rewards[i] for i, agent in enumerate(self.agents)}

        # NOTE: the (0, 2) placeholder mirrors sample_init_state so the state
        # keeps the same array shapes from one step to the next.
        is_wall_disabled = jnp.empty((0, 2), dtype=bool)
        if self.init_state is not None:
            is_wall_disabled = self.compute_disabled_walls(new_agent_pos, state.wall_positions, state.button_positions)

        new_state = TokenEnvState(
            agent_positions=new_agent_pos,
            token_positions=state.token_positions,
            wall_positions=state.wall_positions,
            is_wall_disabled=is_wall_disabled,
            button_positions=state.button_positions,
            is_alive=jnp.logical_and(state.is_alive, jnp.logical_not(collisions)),
            time=state.time + 1
        )

        _dones = jnp.logical_or(collisions, new_state.time >= self.max_steps_in_episode)
        dones = {a: _dones[i] for i, a in enumerate(self.agents)}
        dones.update({"__all__": jnp.all(_dones)})

        obs = self.get_obs(state=new_state)
        info = {}

        return obs, new_state, rewards, dones, info

    @partial(jax.jit, static_argnums=(0,))
    def get_obs(
        self,
        state: TokenEnvState
    ) -> Dict[str, chex.Array]:
        """Build the per-agent observation tensors.

        Every agent sees the grid shifted (with wrap-around) so that it sits
        at cell (0, 0). Channel order: the agent itself; if a layout is
        used, "is wall" and "is blocking wall"; one channel per token id;
        one per other agent; three per button id (pair, button, door).
        Rows padded with negative coordinates in the token/button tables are
        ignored.
        """

        def on_grid(positions):
            # parse() pads ragged token/button tables with -1 coordinates.
            return jnp.all(positions >= 0, axis=-1)

        def obs_for_agent(i):
            base = jnp.zeros(self.obs_shape, dtype=jnp.uint8)
            # ref = self.grid_shape_arr // 2
            ref = jnp.array([0, 0])
            offset = ref - state.agent_positions[i]
            idx_offset = 0
            b = base

            def shift(positions):
                return (positions + offset) % self.grid_shape_arr

            rel_self = shift(state.agent_positions[i])
            b = b.at[idx_offset, rel_self[0], rel_self[1]].set(1)  # Is agent?
            idx_offset += 1

            if self.init_state is not None:
                wall_ch = idx_offset
                rel_walls = shift(state.wall_positions)
                b = b.at[
                    wall_ch, rel_walls[:, 0], rel_walls[:, 1]
                ].set(1).at[  # Is wall?
                    wall_ch + 1, rel_walls[:, 0], rel_walls[:, 1]
                ].set(jnp.logical_not(state.is_wall_disabled).astype(jnp.uint8))  # Is wall blocking?
                idx_offset += 2

            if self.n_tokens > 0:
                token_ch = idx_offset

                def place_token(token_idx, val):
                    cells = state.token_positions[token_idx]
                    rel = shift(cells)
                    # Scatter-max keeps the channel at 0 for padding rows
                    # instead of marking whatever cell they wrap onto.
                    present = on_grid(cells).astype(jnp.uint8)
                    return val.at[token_ch + token_idx, rel[:, 0], rel[:, 1]].max(present)  # Is token?
                b = jax.lax.fori_loop(0, self.n_tokens, place_token, b)
                idx_offset += self.n_tokens

            if self.n_agents > 1:
                other_ch = idx_offset

                def place_other(other_idx, val):
                    # Skip over the observing agent's own index.
                    rel = shift(state.agent_positions[other_idx + (other_idx >= i)])
                    return val.at[other_ch + other_idx, rel[0], rel[1]].set(1)  # Is other agent?
                b = jax.lax.fori_loop(0, self.n_agents - 1, place_other, b)
                idx_offset += self.n_agents - 1

            if self.n_buttons > 0:
                button_ch = idx_offset

                def place_button(button_idx, val):
                    cells = state.button_positions[button_idx]
                    present = on_grid(cells)
                    is_door = jnp.any(
                        jnp.all(
                            cells[:, None, :] == state.wall_positions[None, :, :]
                        , axis=-1)
                    , axis=-1)  # Buttons are considered doors if they are in a wall.
                    rel = shift(cells)
                    ch = button_ch + 3 * button_idx
                    is_plain_button = jnp.logical_and(present, jnp.logical_not(is_door))
                    is_real_door = jnp.logical_and(present, is_door)
                    return val.at[
                        ch, rel[:, 0], rel[:, 1]
                    ].max(present.astype(jnp.uint8)).at[  # Is button-door pair?
                        ch + 1, rel[:, 0], rel[:, 1]
                    ].max(is_plain_button.astype(jnp.uint8)).at[  # Is button?
                        ch + 2, rel[:, 0], rel[:, 1]
                    ].max(is_real_door.astype(jnp.uint8))  # Is door?

                b = jax.lax.fori_loop(0, self.n_buttons, place_button, b)

            # With black_death, dead agents observe the all-zero tensor.
            return jnp.where(jnp.logical_or(jnp.logical_not(self.black_death), state.is_alive[i]), b, base)

        obs = jax.vmap(obs_for_agent)(jnp.arange(self.n_agents))
        return {agent: obs[i] for i, agent in enumerate(self.agents)}

    @partial(jax.jit, static_argnums=(0,))
    def label_f(self, state: TokenEnvState) -> Dict[str, int]:
        """Map each agent name to the id of the token it stands on.

        The value is -1 for agents that are dead or not on any token.
        """

        # [agent, token, repeat, 2] coordinate differences.
        diffs = state.agent_positions[:, None, None, :] - state.token_positions[None, :, :, :]
        matches = jnp.all(diffs == 0, axis=-1)
        matches_any = jnp.any(matches, axis=-1)

        has_match = jnp.any(matches_any, axis=1)
        token_idx = jnp.argmax(matches_any, axis=1)

        agent_token_matches = jnp.where(jnp.logical_and(has_match, state.is_alive), token_idx, -1)

        return {self.agents[agent_idx]: token_idx for agent_idx, token_idx in enumerate(agent_token_matches)}

    @partial(jax.jit, static_argnums=(0,))
    def sample_init_state(
        self,
        key: chex.PRNGKey
    ) -> TokenEnvState:
        """Sample a random map with agents and tokens on distinct cells.

        When ``fixed_map_seed`` is set it replaces ``key``, so the same map
        is produced on every call. The sampled map has no walls or buttons.
        """
        if self.fixed_map_seed is not None:
            key = jax.random.PRNGKey(self.fixed_map_seed)

        grid_points = jnp.stack(jnp.meshgrid(jnp.arange(self.grid_shape[0]), jnp.arange(self.grid_shape[1])), -1)
        grid_flat = grid_points.reshape(-1, 2)

        key, subkey = jax.random.split(key)
        perm = jax.random.permutation(subkey, grid_flat.shape[0])

        # Agents take the first shuffled cells, tokens the ones after them.
        agent_positions = grid_flat[perm][:self.n_agents]
        token_positions = grid_flat[perm][self.n_agents: self.n_agents + self.n_tokens * self.n_token_repeat].reshape(self.n_tokens, self.n_token_repeat, 2)

        return TokenEnvState(
            agent_positions=agent_positions,
            token_positions=token_positions,
            wall_positions=jnp.empty((0, 2), dtype=jnp.int32),
            is_wall_disabled=jnp.empty((0, 2), dtype=bool),
            button_positions=jnp.empty((0, 2), dtype=jnp.int32),
            is_alive=jnp.ones((self.n_agents,), dtype=bool),
            time=0
        )

    @partial(jax.jit, static_argnums=(0,))
    def compute_disabled_walls(self, agent_positions, wall_positions, button_positions):
        """Return one flag per wall telling whether it is currently open.

        A button id counts as pressed when any agent stands on any of its
        cells; a wall is disabled when it shares a cell with a pressed
        button id.
        """
        def _compute_on_buttons(button_pos, agent_positions):
            return jnp.any(
                jnp.any(
                    jnp.all(
                        button_pos[:, None, :] == agent_positions[None, :, :]  # [M, N, 2]
                    , axis=-1)  # [M, N,]
                , axis=-1)  # [M,]
            , axis=-1)  # [1,]
        on_buttons = jax.vmap(_compute_on_buttons, in_axes=(0, None))(button_positions, agent_positions)

        def _compute_disabled_walls(wall_pos, on_buttons, button_positions):
            # Compare each wall_pos to each button coordinate
            eq = jnp.all(button_positions == wall_pos, axis=-1)  # (n_buttons, n_button_repeat)
            # A wall is disabled if *any* matching button is pressed
            return jnp.any(jnp.logical_and(on_buttons, jnp.any(eq, axis=-1)))
        return jax.vmap(_compute_disabled_walls, in_axes=(0, None, None))(wall_positions, on_buttons, button_positions)

    def parse(self, layout: str) -> TokenEnvState:
        """Parse a textual layout and return the matching initial state.

        Side effect: overrides ``grid_shape``, ``grid_shape_arr``,
        ``n_agents``, ``agents``, ``n_tokens``, ``n_token_repeat``,
        ``n_buttons`` and ``n_button_repeat`` on ``self``.

        Raises:
            ValueError: If the layout is malformed, empty, ragged, places
                two items of the same kind in one cell, repeats an agent,
                or labels agents non-contiguously.
            AssertionError: If a cell mixes a wall with an agent or token,
                or a token with a button.
        """
        # Example layout:
        # [ 8 ][   ][   ][   ][   ][   ][   ][ # ][ 0 ][   ][   ][ 1 ]
        # [   ][   ][   ][   ][   ][   ][   ][ # ][   ][   ][   ][   ]
        # [   ][   ][ b ][   ][   ][   ][   ][ # ][   ][   ][   ][   ]
        # [   ][   ][   ][   ][   ][   ][   ][ # ][ 3 ][   ][   ][ 2 ]
        # [   ][   ][   ][   ][   ][   ][   ][ # ][ # ][ # ][#,a][ # ]
        # [ A ][   ][   ][   ][   ][   ][   ][   ][   ][   ][   ][   ]
        # [ B ][   ][   ][   ][   ][   ][   ][   ][   ][   ][   ][   ]
        # [   ][   ][   ][   ][   ][   ][   ][ # ][ # ][ # ][#,b][ # ]
        # [   ][   ][   ][   ][   ][   ][   ][ # ][ 4 ][   ][   ][ 5 ]
        # [   ][   ][ a ][   ][   ][   ][   ][ # ][   ][   ][   ][   ]
        # [   ][   ][   ][   ][   ][   ][   ][ # ][   ][   ][   ][   ]
        # [ 9 ][   ][   ][   ][   ][   ][   ][ # ][ 7 ][   ][   ][ 6 ]
        # where each [] indicates a cell, uppercase letters indicate agents,
        # e.g., A and B, and lowercase letters indicate "sync" points, which
        # act as doors and the buttons that open them: e.g., [#,a] is a door
        # that is open while an agent stands on a cell [ a ] and closed
        # otherwise. Cells with # indicate walls.

        # --- parse into a 2D list of cell contents (strings inside brackets) ---
        lines = [ln.strip() for ln in layout.strip().splitlines() if ln.strip()]
        rows: list[list[str]] = []

        for ln in lines:
            cells = []
            i = 0
            while i < len(ln):
                if ln[i] == "[":
                    j = ln.find("]", i + 1)
                    if j == -1:
                        raise ValueError("Malformed layout: missing closing ']' in a row.")
                    cells.append(ln[i + 1:j].strip())
                    i = j + 1
                else:
                    i += 1
            if cells:
                rows.append(cells)

        if not rows:
            raise ValueError("Parsed layout is empty.")

        H = len(rows)
        W = len(rows[0])
        if any(len(r) != W for r in rows):
            raise ValueError("All rows in the layout must have the same number of cells.")

        # --- collect info ---
        wall_positions: list[tuple[int, int]] = []
        agent_positions: dict[int, tuple[int, int]] = {}
        token_positions: dict[int, list[tuple[int, int]]] = {}
        button_positions: dict[int, list[tuple[int, int]]] = {}
        max_token_id = -1
        max_button_id = -1

        def _is_wall(c: str) -> bool:
            return c == "#"

        def _is_agent(c: str) -> bool:
            return len(c) == 1 and "A" <= c <= "Z"

        def _get_agent_idx(c: str) -> int:
            # Plain int (not a NumPy scalar) so derived sizes stay ints.
            return ord(c) - ord("A")

        def _is_token(c: str) -> bool:
            try:
                int(c)
            except ValueError:
                return False
            return True

        def _is_button(c: str) -> bool:
            return len(c) == 1 and "a" <= c <= "z"

        def _get_button_idx(c: str) -> int:
            return ord(c) - ord("a")

        for r in range(H):
            for c in range(W):
                cell = rows[r][c]
                has_wall = False
                has_agent = False
                has_token = False
                has_button = False
                for raw_content in cell.split(","):
                    # Strip so that "[ #, a ]" is read the same as "[#,a]".
                    content = raw_content.strip()

                    if _is_wall(content):
                        if has_wall:
                            raise ValueError("One wall per cell.")
                        has_wall = True
                        wall_positions.append((r, c))

                    if _is_agent(content):
                        if has_agent:
                            raise ValueError("One agent per cell.")
                        has_agent = True
                        idx = _get_agent_idx(content)
                        if idx in agent_positions:
                            raise ValueError(f"Duplicate placement for agent '{content}'.")
                        agent_positions[idx] = (r, c)

                    if _is_token(content):
                        if has_token:
                            raise ValueError("One token per cell.")
                        has_token = True
                        tok_id = int(content)
                        max_token_id = max(max_token_id, tok_id)
                        token_positions.setdefault(tok_id, []).append((r, c))

                    if _is_button(content):
                        if has_button:
                            raise ValueError("One button per cell.")
                        has_button = True
                        idx = _get_button_idx(content)
                        max_button_id = max(max_button_id, idx)
                        button_positions.setdefault(idx, []).append((r, c))

                assert not (has_wall and has_agent), "A wall cell cannot hold an agent."
                assert not (has_wall and has_token), "A wall cell cannot hold a token."
                assert not (has_token and has_button), "A cell cannot hold both a token and a button."

        # Agent letters index rows of the position table, so they must be
        # exactly A, B, ... without gaps.
        if sorted(agent_positions) != list(range(len(agent_positions))):
            raise ValueError("Agents must be labelled contiguously starting from 'A'.")

        # --- override environment settings ---
        self.grid_shape = (H, W)
        self.grid_shape_arr = jnp.array(self.grid_shape)

        self.n_agents = len(agent_positions)
        self.agents = [f"agent_{i}" for i in range(self.n_agents)]

        # Ids with fewer placements than the maximum keep (-1, -1) padding.
        self.n_tokens = max_token_id + 1 if max_token_id >= 0 else 0
        self.n_token_repeat = max((len(v) for v in token_positions.values()), default=0)
        token_positions_np = np.full((self.n_tokens, self.n_token_repeat, 2), -1, dtype=np.int32)
        for tid in range(self.n_tokens):
            for k, pos in enumerate(token_positions.get(tid, [])):
                token_positions_np[tid, k] = pos

        agent_positions_np = np.full((self.n_agents, 2), -1, dtype=np.int32)
        for idx, pos in agent_positions.items():
            agent_positions_np[idx] = pos

        wall_positions_np = np.array(wall_positions, dtype=np.int32) if wall_positions else np.empty((0, 2), dtype=np.int32)

        self.n_buttons = max_button_id + 1 if max_button_id >= 0 else 0
        self.n_button_repeat = max((len(v) for v in button_positions.values()), default=0)
        button_positions_np = np.full((self.n_buttons, self.n_button_repeat, 2), -1, dtype=np.int32)
        for bid in range(self.n_buttons):
            for k, pos in enumerate(button_positions.get(bid, [])):
                button_positions_np[bid, k] = pos

        agent_positions_jnp = jnp.array(agent_positions_np)
        wall_positions_jnp = jnp.array(wall_positions_np)
        button_positions_jnp = jnp.array(button_positions_np)

        # --- return state ---
        return TokenEnvState(
            agent_positions=agent_positions_jnp,
            token_positions=jnp.array(token_positions_np),
            wall_positions=wall_positions_jnp,
            is_wall_disabled=self.compute_disabled_walls(agent_positions_jnp, wall_positions_jnp, button_positions_jnp),
            button_positions=button_positions_jnp,
            is_alive=jnp.ones((self.n_agents,), dtype=bool),
            time=0,
        )

    def render(self, state: TokenEnvState):
        """Print an ASCII table of the grid for ``state``.

        Enabled walls print as ``#``, tokens as their id, buttons as
        ``b_<id>`` and agents as ``A_<idx>``; empty cells print as ``.``.
        Padding rows (negative coordinates) in the token/button tables are
        skipped.
        """
        empty_cell = "."
        wall_cell = "#"
        grid = np.full(self.grid_shape, empty_cell, dtype=object)

        # Pull everything to host memory once instead of per element.
        wall_positions = np.asarray(state.wall_positions)
        is_wall_disabled = np.asarray(state.is_wall_disabled)
        token_positions = np.asarray(state.token_positions)
        button_positions = np.asarray(state.button_positions)
        agent_positions = np.asarray(state.agent_positions)

        for disabled, pos in zip(is_wall_disabled, wall_positions):
            if not disabled:
                grid[pos[0], pos[1]] = f"{wall_cell}"

        for token, positions in enumerate(token_positions):
            for pos in positions:
                if pos[0] < 0 or pos[1] < 0:
                    # Padding; a negative index would hit the last row/column.
                    continue
                grid[pos[0], pos[1]] = f"{token}"

        for button, positions in enumerate(button_positions):
            for pos in positions:
                if pos[0] < 0 or pos[1] < 0:
                    continue
                current = grid[pos[0], pos[1]]
                if current == empty_cell:
                    grid[pos[0], pos[1]] = f"b_{button}"
                else:
                    grid[pos[0], pos[1]] = f"{current},b_{button}"

        for agent in range(self.n_agents):
            pos = agent_positions[agent]
            current = grid[pos[0], pos[1]]
            if current == empty_cell:
                grid[pos[0], pos[1]] = f"A_{agent}"
            else:
                grid[pos[0], pos[1]] = f"A_{agent},{current}"

        max_width = max(len(str(cell)) for row in grid for cell in row)

        h_line = "+" + "+".join(["-" * (max_width + 2) for _ in range(self.grid_shape[1])]) + "+"
        out_lines = [h_line]
        for row in grid:
            out_lines.append("| " + " | ".join(f"{str(cell):<{max_width}}" for cell in row) + " |")
            out_lines.append(h_line)

        print("\n".join(out_lines) + "\n")
|
|
571
|
+
|