dfa-gym 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dfa_gym/__init__.py +5 -15
- dfa_gym/dfa_bisim_env.py +121 -0
- dfa_gym/dfa_wrapper.py +185 -52
- dfa_gym/env.py +168 -0
- dfa_gym/maps/2buttons_2agents.pdf +0 -0
- dfa_gym/maps/2rooms_2agents.pdf +0 -0
- dfa_gym/maps/4buttons_4agents.pdf +0 -0
- dfa_gym/maps/4rooms_4agents.pdf +0 -0
- dfa_gym/robot.png +0 -0
- dfa_gym/spaces.py +156 -0
- dfa_gym/token_env.py +571 -0
- dfa_gym/utils.py +266 -0
- dfa_gym-0.2.0.dist-info/METADATA +93 -0
- dfa_gym-0.2.0.dist-info/RECORD +16 -0
- {dfa_gym-0.1.0.dist-info → dfa_gym-0.2.0.dist-info}/WHEEL +1 -1
- dfa_gym/dfa_env.py +0 -45
- dfa_gym-0.1.0.dist-info/METADATA +0 -11
- dfa_gym-0.1.0.dist-info/RECORD +0 -7
- {dfa_gym-0.1.0.dist-info → dfa_gym-0.2.0.dist-info}/licenses/LICENSE +0 -0
dfa_gym/utils.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import matplotlib.image as mpimg
|
|
4
|
+
import matplotlib.patches as patches
|
|
5
|
+
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
|
|
6
|
+
from matplotlib.animation import FuncAnimation, PillowWriter
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def parse_map(map_lines):
    """Turn ASCII map lines into a 2D list of cell strings.

    Each ``[...]`` group on a line is one cell; the text between the brackets
    is stripped of surrounding whitespace. Lines that contain no bracketed
    cell at all (for example blank lines) are dropped from the result.
    """
    cell_pattern = re.compile(r"\[(.*?)\]")
    rows = (cell_pattern.findall(text) for text in map_lines)
    return [[cell.strip() for cell in row] for row in rows if row]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def visualize(layout, figsize, cell_size=1, save_path=None, trace=None):
    """Draw an ASCII grid layout with matplotlib, optionally animating a trace.

    Cell contents are interpreted as follows:

    * ``#``                     -- wall
    * upper-case letters        -- agent start position
    * digits                    -- token (gold circle labelled with the digit)
    * lower-case ``a``-``d``    -- sync button, coloured by letter
    * ``#`` plus a letter (e.g. ``#,a``) -- door, coloured like its button
    * any other cell containing ``,`` -- drawn as a plain "firebrick" square

    Args:
        layout: Multi-line string; every ``[...]`` group is one cell.
        figsize: Figure size forwarded to ``plt.subplots``.
        cell_size: Side length of the squares drawn for each cell.
        save_path: If given, the figure is saved there; with a ``trace`` an
            animated GIF is additionally written next to it (same path with
            the extension replaced by ``.gif``). If omitted, the figure is
            shown interactively instead.
        trace: Optional sequence of frames. Each frame is read as
            ``frame.env_state.agent_positions[i]`` (indexed ``[row, col]``)
            and ``frame.env_state.is_wall_disabled[j]``.
            NOTE(review): ``j`` follows the row-major order in which walls and
            doors appear in ``layout`` -- confirm this matches the environment.

    Raises:
        ValueError: If ``layout`` has no cells, or a lower-case cell contains
            none of the supported button letters ``a``-``d``.
    """
    import os  # local import so this function does not depend on file-level changes

    grid = parse_map(layout.splitlines())
    if not grid:
        raise ValueError("layout contains no '[...]' cells to draw")
    n_rows, n_cols = len(grid), len(grid[0])

    # Prefer the robot.png shipped next to this module so the function works
    # from any working directory; fall back to the old CWD-relative lookup.
    robot_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "robot.png")
    if not os.path.exists(robot_path):
        robot_path = "robot.png"
    zoom = 0.05  # robot image scale; adjust as needed

    # Checked in this order, so a cell containing several letters takes the
    # colour of the first match.
    button_colors = (("a", "red"), ("b", "green"), ("c", "blue"), ("d", "pink"))
    hatched = {"hatch": "||", "hatch_linewidth": 3, "fill": True}

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)
    ax.set_aspect("equal")
    ax.axis("off")

    # Background tiles. Row 0 of the layout is drawn at the top, hence the
    # flipped y coordinate.
    for r in range(n_rows):
        for c in range(n_cols):
            ax.add_patch(patches.Rectangle(
                (c, n_rows - r - 1), cell_size, cell_size,
                facecolor="lightgray", edgecolor="white", lw=1
            ))

    agent_positions = {}  # agent letter -> centre (x, y) of its start cell
    wall_positions = {}   # (x, y) of cell corner -> colour; drawn later
    for r in range(n_rows):
        for c in range(n_cols):
            content = grid[r][c]
            x, y = c, n_rows - r - 1

            if not content:
                continue

            if content == "#":  # wall
                wall_positions[(x, y)] = "dimgray"

            elif content.isupper():  # agents
                agent_positions[content] = (x + 0.5, y + 0.5)

            elif content.isdigit():  # tokens
                ax.add_patch(patches.Circle(
                    (x + 0.5, y + 0.5), 0.4,
                    facecolor="gold", edgecolor="orange", lw=1.5
                ))
                ax.text(x + 0.5, y + 0.5, content,
                        ha="center", va="center",
                        fontsize=24, color="black", weight="bold")

            elif content.islower():  # sync button, or door such as "#,a"
                color = next(
                    (shade for letter, shade in button_colors if letter in content),
                    None,
                )
                if color is None:
                    raise ValueError(
                        f"unsupported button/door cell {content!r}: "
                        "expected one of the letters a, b, c, d"
                    )
                if "#" in content:
                    # Doors are drawn together with the walls further below.
                    wall_positions[(x, y)] = color
                else:
                    ax.add_patch(patches.Rectangle(
                        (x, y), cell_size, cell_size,
                        facecolor=color, edgecolor="black", lw=1.5
                    ))

            elif "," in content:  # comma-separated cell without any letters
                ax.add_patch(patches.Rectangle(
                    (x, y), cell_size, cell_size,
                    facecolor="firebrick", edgecolor="black", lw=1.5
                ))

    if trace is not None:
        n_agents = len(agent_positions)
        n_frames = len(trace)

        # Load robot image once
        robot_img = mpimg.imread(robot_path)

        # Numeric labels to tell the agents apart
        agent_labels = [str(i + 1) for i in range(n_agents)]

        # Artists drawn for the current frame; removed before the next one
        current_boxes = []
        current_texts = []
        current_walls = []
        current_timestep = []

        def update(frame):
            # Remove everything drawn for the previous frame
            for bucket in (current_boxes, current_texts, current_walls, current_timestep):
                for artist in bucket:
                    artist.remove()
                bucket.clear()

            state = trace[frame].env_state

            # Robot image and label for every agent
            for agent_idx in range(n_agents):
                pos = state.agent_positions[agent_idx]
                x = pos[1] + 0.5
                y = n_rows - pos[0] - 0.5

                ab = AnnotationBbox(OffsetImage(robot_img, zoom=zoom), (x, y), frameon=False)
                ax.add_artist(ab)
                current_boxes.append(ab)

                txt = ax.text(x+0.3, y + 0.3, agent_labels[agent_idx],
                              ha='center', va='bottom', color='black', weight='bold', fontsize=10)
                current_texts.append(txt)

            # Walls/doors are redrawn every frame: plain walls and entries
            # flagged in is_wall_disabled are solid, the others are hatched.
            for i, ((x, y), color) in enumerate(wall_positions.items()):
                solid = state.is_wall_disabled[i] or color == "dimgray"
                extra = {} if solid else hatched
                rect = ax.add_patch(patches.Rectangle(
                    (x, y), cell_size, cell_size,
                    facecolor=color, edgecolor="black", lw=1.5,
                    **extra
                ))
                current_walls.append(rect)

            # Timestep caption above the grid
            ts = ax.text(n_cols / 2, n_rows + 0.5, f"Time step: {frame}",
                         ha='center', va='bottom', color='black', weight='bold', fontsize=14)
            current_timestep.append(ts)

            return current_boxes + current_texts + current_walls + current_timestep

        # Keep a reference until show/save below, otherwise the animation may
        # be garbage-collected before it runs.
        anim = FuncAnimation(fig, update, frames=n_frames, interval=500, blit=False)

        if save_path:
            # Swap only the extension, so a path without ".pdf" does not end
            # up with the GIF and the static figure written to the same file.
            gif_path = os.path.splitext(save_path)[0] + ".gif"
            anim.save(gif_path, writer=PillowWriter(fps=2))

    else:
        if agent_positions:
            # Read the image once instead of once per agent
            robot_img = plt.imread(robot_path)
            for x, y in agent_positions.values():
                ab = AnnotationBbox(OffsetImage(robot_img, zoom=zoom), (x, y), frameon=False)
                ax.add_artist(ab)

        for (x, y), color in wall_positions.items():
            ax.add_patch(patches.Rectangle(
                (x, y), cell_size, cell_size,
                facecolor=color, edgecolor="black", lw=1.5,
                **hatched
            ))

    if save_path:
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    else:
        plt.show()
    plt.close(fig)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
if __name__ == "__main__":
    # Manual preview entry point: renders one of the hand-written layouts.
    # Cell legend: "#" wall, digits tokens, upper-case letters agents,
    # lower-case letters buttons, "#,<letter>" doors.

    # Active layout: four agents (A-D) in a central corridor, with
    # button-controlled doors on both sides.
    layout = """
    [ # ][ # ][ # ][ # ][ # ][ ][ ][ ][ 0 ][ ][ ][ ][ # ][ # ][ # ][ # ][ # ]
    [ # ][ 0 ][ ][ 1 ][#,c][ ][ c ][ ][ A ][ ][ a ][ ][#,a][ 0 ][ ][ 2 ][ # ]
    [ # ][ ][ 4 ][ ][#,c][ ][ c ][ ][ ][ ][ a ][ ][#,a][ ][ 8 ][ ][ # ]
    [ # ][ 3 ][ ][ 2 ][#,c][ ][ c ][ ][ B ][ ][ a ][ ][#,a][ 6 ][ ][ 4 ][ # ]
    [ # ][ # ][ # ][ # ][ # ][ 2 ][ ][ ][ ][ ][ ][ 3 ][ # ][ # ][ # ][ # ][ # ]
    [ # ][ 5 ][ ][ 6 ][#,d][ ][ d ][ ][ C ][ ][ b ][ ][#,b][ 1 ][ ][ 3 ][ # ]
    [ # ][ ][ 9 ][ ][#,d][ ][ d ][ ][ ][ ][ b ][ ][#,b][ ][ 9 ][ ][ # ]
    [ # ][ 8 ][ ][ 7 ][#,d][ ][ d ][ ][ D ][ ][ b ][ ][#,b][ 7 ][ ][ 5 ][ # ]
    [ # ][ # ][ # ][ # ][ # ][ ][ ][ ][ 1 ][ ][ ][ ][ # ][ # ][ # ][ # ][ # ]
    """
    # Alternative layouts, kept for reference; uncomment one to preview it.
    # Four rooms, four agents:
    # layout = """
    # [ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ]
    # [ # ][ 0 ][ ][ 2 ][ # ][ 0 ][ ][ 1 ][ # ][ 5 ][ ][ 6 ][ # ][ 1 ][ ][ 3 ][ # ]
    # [ # ][ ][ ][ ][ # ][ ][ ][ ][ # ][ ][ ][ ][ # ][ ][ ][ ][ # ]
    # [ # ][ ][ ][ ][ # ][ ][ B ][ ][ # ][ ][ a ][ ][ # ][ ][ ][ ][ # ]
    # [ # ][ a ][ 8 ][ A ][#,a][ ][ 4 ][ ][#,a][ ][ 9 ][ ][#,a][ D ][ 9 ][ a ][ # ]
    # [ # ][ ][ ][ ][ # ][ ][ a ][ ][ # ][ ][ C ][ ][ # ][ ][ ][ ][ # ]
    # [ # ][ ][ ][ ][ # ][ ][ ][ ][ # ][ ][ ][ ][ # ][ ][ ][ ][ # ]
    # [ # ][ 6 ][ ][ 4 ][ # ][ 3 ][ ][ 2 ][ # ][ 8 ][ ][ 7 ][ # ][ 7 ][ ][ 5 ][ # ]
    # [ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ]
    # """
    # Two buttons, two agents:
    # layout = """
    # [ 0 ][ ][ ][ ][ # ][ # ][ # ][ # ][ # ]
    # [ ][ ][ a ][ ][#,a][ 0 ][ ][ 2 ][ # ]
    # [ A ][ ][ a ][ ][#,a][ ][ 8 ][ ][ # ]
    # [ ][ ][ a ][ ][#,a][ 6 ][ ][ 4 ][ # ]
    # [ 1 ][ ][ ][ 3 ][ # ][ # ][ # ][ # ][ # ]
    # [ ][ ][ b ][ ][#,b][ 1 ][ ][ 3 ][ # ]
    # [ B ][ ][ b ][ ][#,b][ ][ 9 ][ ][ # ]
    # [ ][ ][ b ][ ][#,b][ 7 ][ ][ 5 ][ # ]
    # [ 2 ][ ][ ][ ][ # ][ # ][ # ][ # ][ # ]
    # """
    # Two rooms, two agents:
    # layout = """
    # [ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ]
    # [ # ][ 0 ][ ][ 2 ][ # ][ 1 ][ ][ 3 ][ # ]
    # [ # ][ ][ ][ ][ # ][ ][ ][ ][ # ]
    # [ # ][ ][ ][ ][ # ][ ][ ][ ][ # ]
    # [ # ][ a ][ 8 ][ A ][#,a][ B ][ 9 ][ a ][ # ]
    # [ # ][ ][ ][ ][ # ][ ][ ][ ][ # ]
    # [ # ][ ][ ][ ][ # ][ ][ ][ ][ # ]
    # [ # ][ 6 ][ ][ 4 ][ # ][ 7 ][ ][ 5 ][ # ]
    # [ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ][ # ]
    # """

    # Pass save_path to write the figure to disk instead of showing it:
    # visualize(layout, figsize=(17,9), save_path="maps/4buttons_4agents.pdf")
    visualize(layout, figsize=(17,9))
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: dfa-gym
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Python library for playing DFA bisimulation games and wrapping other RL environments with DFA goals.
|
|
5
|
+
Author-email: Beyazit Yalcinkaya <beyazit@berkeley.edu>
|
|
6
|
+
License-File: LICENSE
|
|
7
|
+
Requires-Python: >=3.10
|
|
8
|
+
Requires-Dist: dfax>=0.1.1
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
|
|
11
|
+
# dfa-gym
|
|
12
|
+
|
|
13
|
+
This repo implements (Multi-Agent) Reinforcement Learning environments in JAX for solving objectives given as Deterministic Finite Automata (DFAs). There are three environments:
|
|
14
|
+
|
|
15
|
+
1. `TokenEnv` is a fully observable grid environment with tokens in cells. The grid can be created randomly or from a specific layout. It can be instantiated in both single- and multi-agent settings.
|
|
16
|
+
2. `DFAWrapper` is an environment wrapper assigning tasks represented as Deterministic Finite Automata (DFAs) to the agents in the wrapped environment. DFAs are represented as [`DFAx`](https://github.com/rad-dfa/dfax) objects.
|
|
17
|
+
3. `DFABisimEnv` is an environment for solving DFA bisimulation games to learn RAD Embeddings, provably correct latent DFA representations, as described in [this paper](https://arxiv.org/pdf/2503.05042).
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
## Installation
|
|
21
|
+
|
|
22
|
+
This package will soon be made pip-installable. In the meantime, clone the repo and install it locally.
|
|
23
|
+
|
|
24
|
+
```
|
|
25
|
+
git clone https://github.com/rad-dfa/dfa-gym.git
|
|
26
|
+
pip install -e dfa-gym
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
## TokenEnv
|
|
30
|
+
|
|
31
|
+
Create a grid world with token and agent positions assigned randomly.
|
|
32
|
+
|
|
33
|
+
```python
|
|
34
|
+
from dfa_gym import TokenEnv
|
|
35
|
+
|
|
36
|
+
env = TokenEnv(
|
|
37
|
+
n_agents=1, # Single agent
|
|
38
|
+
n_tokens=10, # 10 different token types
|
|
39
|
+
n_token_repeat=2, # Each token repeated twice
|
|
40
|
+
grid_shape=(7, 7), # Shape of the grid
|
|
41
|
+
fixed_map_seed=None, # If not None, then samples the same map using the given seed
|
|
42
|
+
max_steps_in_episode=100, # Episode length is 100
|
|
43
|
+
)
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
Create a grid world from a given layout.
|
|
47
|
+
|
|
48
|
+
```python
|
|
49
|
+
layout = """
|
|
50
|
+
[ 0 ][ ][ ][ ][ # ][ # ][ # ][ # ][ # ]
|
|
51
|
+
[ ][ ][ a ][ ][#,a][ 0 ][ ][ 2 ][ # ]
|
|
52
|
+
[ A ][ ][ a ][ ][#,a][ ][ 8 ][ ][ # ]
|
|
53
|
+
[ ][ ][ a ][ ][#,a][ 6 ][ ][ 4 ][ # ]
|
|
54
|
+
[ 1 ][ ][ ][ 3 ][ # ][ # ][ # ][ # ][ # ]
|
|
55
|
+
[ ][ ][ b ][ ][#,b][ 1 ][ ][ 3 ][ # ]
|
|
56
|
+
[ B ][ ][ b ][ ][#,b][ ][ 9 ][ ][ # ]
|
|
57
|
+
[ ][ ][ b ][ ][#,b][ 7 ][ ][ 5 ][ # ]
|
|
58
|
+
[ 2 ][ ][ ][ ][ # ][ # ][ # ][ # ][ # ]
|
|
59
|
+
"""
|
|
60
|
+
env = TokenEnv(
|
|
61
|
+
layout=layout, # Set layout, where each [] indicates a cell, uppercase letters are
|
|
62
|
+
# agents, # are walls, and lower case letters are buttons when alone
|
|
63
|
+
# and doors when paired with a wall. For example, [#,a] is a door
|
|
64
|
+
# that is open if an agent is on a [ a ] cell and closed otherwise.
|
|
65
|
+
)
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
## DFAWrapper
|
|
70
|
+
|
|
71
|
+
Wrap a `TokenEnv` instance using `DFAWrapper`.
|
|
72
|
+
|
|
73
|
+
```python
|
|
74
|
+
from dfa_gym import DFAWrapper
|
|
75
|
+
from dfax.samplers import ReachSampler
|
|
76
|
+
|
|
77
|
+
env = DFAWrapper(
|
|
78
|
+
env=TokenEnv(layout=layout),
|
|
79
|
+
sampler=ReachSampler()
|
|
80
|
+
)
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
## DFABisimEnv
|
|
84
|
+
|
|
85
|
+
Create a DFA bisimulation game.
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
from dfa_gym import DFABisimEnv
|
|
89
|
+
from dfax.samplers import RADSampler
|
|
90
|
+
|
|
91
|
+
env = DFABisimEnv(sampler=RADSampler())
|
|
92
|
+
```
|
|
93
|
+
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
dfa_gym/__init__.py,sha256=8rauoRND6VqAFw1axw_xcFBOxIzHi9MrrN1d57y_bL4,185
|
|
2
|
+
dfa_gym/dfa_bisim_env.py,sha256=QKh4ebg2HEENSODuY1x77IKySx8MIC9aM9iwA5cAF8o,4394
|
|
3
|
+
dfa_gym/dfa_wrapper.py,sha256=o500-Zl8FELYH4dqL111iuul4THeZWuXrD_RlHVzuBY,6353
|
|
4
|
+
dfa_gym/env.py,sha256=belmfaFHB_dYjnyNCe_zjniSOhldSK_1zXy2W9_FfBU,5415
|
|
5
|
+
dfa_gym/robot.png,sha256=GdWmACflIoWlRBdwJq_rNdInksWqMuEcuLd9KAI8uQE,18616
|
|
6
|
+
dfa_gym/spaces.py,sha256=jIBLrCSEwsnnSQRZ0xVebX-KjdpX_X_5CXQjcX5V7mo,4696
|
|
7
|
+
dfa_gym/token_env.py,sha256=LWMqAh9K8XY33Uuv8epMwo8whz-RJhOHlyANpSQvzoE,24227
|
|
8
|
+
dfa_gym/utils.py,sha256=DE32KxJ7LixEF47f60TX-6q6-uOK6uw7pnnwbhaZFEM,10772
|
|
9
|
+
dfa_gym/maps/2buttons_2agents.pdf,sha256=eH5iwwCWXbyWWJteG4GcnVuA4_8C9jWZTiB6DmKwzHU,32646
|
|
10
|
+
dfa_gym/maps/2rooms_2agents.pdf,sha256=LfUrnDuTyBmhNUp5wljbcxlU85fldRWy0sFVJ7f_aCY,31668
|
|
11
|
+
dfa_gym/maps/4buttons_4agents.pdf,sha256=xABZVV-8Np0gY_6FiAONf-UXnT07_5u34_d6bAOMdo0,57783
|
|
12
|
+
dfa_gym/maps/4rooms_4agents.pdf,sha256=1S_9APYr18sEtIgJtY0SHU44M-z6EgGIA6JTYSrk_lw,56058
|
|
13
|
+
dfa_gym-0.2.0.dist-info/METADATA,sha256=t2G1j_wlOFJuBNypEBFKSO8fr5J9o3bjkVJIYaTDbsE,3157
|
|
14
|
+
dfa_gym-0.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
15
|
+
dfa_gym-0.2.0.dist-info/licenses/LICENSE,sha256=Cvu0BZqt3rcFFv70hcFDgD_y8ryOKW85F-qGRfYI4iM,1071
|
|
16
|
+
dfa_gym-0.2.0.dist-info/RECORD,,
|
dfa_gym/dfa_env.py
DELETED
|
@@ -1,45 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import gymnasium as gym
|
|
3
|
-
from gymnasium import spaces
|
|
4
|
-
from dfa_samplers import DFASampler, RADSampler
|
|
5
|
-
|
|
6
|
-
from typing import Any
|
|
7
|
-
|
|
8
|
-
__all__ = ["DFAEnv"]
|
|
9
|
-
|
|
10
|
-
class DFAEnv(gym.Env):
    """Gymnasium environment whose state is a DFA advanced one token per step.

    ``reset`` samples a DFA from the sampler; each ``step`` advances it by the
    chosen token and minimizes it. The observation is the decimal digits of
    ``dfa.to_int()``, left-padded with zeros to ``sampler.get_size_bound()``.
    """

    def __init__(
        self,
        sampler: DFASampler | None = None,
        timeout: int = 100
    ):
        """Set up spaces from the sampler.

        Args:
            sampler: Source of DFAs; a default ``RADSampler()`` is used when
                omitted.
            timeout: Step budget used in the ``done`` check of ``step``.
        """
        super().__init__()
        self.sampler = sampler if sampler is not None else RADSampler()
        self.size_bound = self.sampler.get_size_bound()
        # One discrete action per token of the sampler's alphabet.
        self.action_space = spaces.Discrete(self.sampler.n_tokens)
        # Every observation entry is a single decimal digit (0-9).
        self.observation_space = spaces.Box(low=0, high=9, shape=(self.size_bound,), dtype=np.int64)
        self.dfa = None  # set by reset()
        self.timeout = timeout
        self.t = None    # step counter, set by reset()

    def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[np.ndarray, dict[str, Any]]:
        """Sample a fresh DFA and return ``(observation, info)``.

        ``options`` is accepted for API compatibility and is not used.
        """
        # Seed Gymnasium's own generator (self.np_random) as its reset
        # contract expects.
        super().reset(seed=seed)
        # Also seed NumPy's global RNG, as the original code did.
        # NOTE(review): presumably the sampler draws from it -- confirm.
        np.random.seed(seed)
        self.dfa = self.sampler.sample()
        self.t = 0
        return self._get_dfa_obs(), {}

    def step(self, action: int) -> tuple[np.ndarray, int, bool, bool, dict[str, Any]]:
        """Advance the DFA by ``action`` and return the Gymnasium 5-tuple.

        Reward is ``1`` when the label of the new start state is truthy,
        ``-1`` when ``find_word()`` returns ``None``, and ``0`` otherwise.
        The episode is reported as terminated on any non-zero reward or once
        the step counter exceeds ``timeout``; ``truncated`` is always False.

        Raises:
            RuntimeError: If called before ``reset``.
        """
        if self.dfa is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError("DFAEnv.step() called before reset()")
        self.dfa = self.dfa.advance([action]).minimize()
        reward = 0
        if self.dfa._label(self.dfa.start):
            # NOTE(review): presumably "start state is accepting" -- confirm.
            reward = 1
        elif self.dfa.find_word() is None:
            # NOTE(review): presumably "no accepting word remains" -- confirm.
            reward = -1
        self.t += 1
        # The counter is compared after the increment with ">", so up to
        # timeout + 1 steps can run before the time limit ends the episode.
        done = reward != 0 or self.t > self.timeout
        return self._get_dfa_obs(), reward, done, False, {}

    def _get_dfa_obs(self) -> np.ndarray:
        """Encode the current DFA as a fixed-length vector of decimal digits."""
        dfa_obs = np.array([int(i) for i in str(self.dfa.to_int())])
        # Left-pad with zeros up to size_bound; np.pad raises if the encoding
        # is longer than size_bound (the pad width would be negative).
        obs = np.pad(dfa_obs, (self.size_bound - dfa_obs.shape[0], 0), constant_values=0)
        return obs
|
dfa_gym-0.1.0.dist-info/METADATA
DELETED
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.4
|
|
2
|
-
Name: dfa-gym
|
|
3
|
-
Version: 0.1.0
|
|
4
|
-
Summary: Gymnasium environment for solving DFAs and wrapping other environments with DFA goals
|
|
5
|
-
License-File: LICENSE
|
|
6
|
-
Requires-Python: >=3.12
|
|
7
|
-
Requires-Dist: dfa-samplers>=0.1.0
|
|
8
|
-
Requires-Dist: gymnasium>=1.0.0
|
|
9
|
-
Description-Content-Type: text/markdown
|
|
10
|
-
|
|
11
|
-
# dfa-gym
|
dfa_gym-0.1.0.dist-info/RECORD
DELETED
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
dfa_gym/__init__.py,sha256=tLY48NluNVv66znFnlR7j9o-pRW5caaO766W724__HY,364
|
|
2
|
-
dfa_gym/dfa_env.py,sha256=u-mOCPhXRljp2t-VmvDfsHbKIdHdn59UGomHxUa_BxQ,1601
|
|
3
|
-
dfa_gym/dfa_wrapper.py,sha256=11eqfyl6g2v-wILGMWLg9L2sMJYnwl5rMn11p9YbQF0,2283
|
|
4
|
-
dfa_gym-0.1.0.dist-info/METADATA,sha256=W_S2r-zMFEX9oeVJa3kCVCStnhQ91rNj-tA_FO6oBm8,309
|
|
5
|
-
dfa_gym-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
dfa_gym-0.1.0.dist-info/licenses/LICENSE,sha256=Cvu0BZqt3rcFFv70hcFDgD_y8ryOKW85F-qGRfYI4iM,1071
|
|
7
|
-
dfa_gym-0.1.0.dist-info/RECORD,,
|
|
File without changes
|