synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +575 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +104 -6
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,675 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import Optional, Dict, Any, Tuple
|
3
|
+
|
4
|
+
from synth_ai.environments.environment.shared_engine import (
|
5
|
+
GetObservationCallable,
|
6
|
+
InternalObservation,
|
7
|
+
)
|
8
|
+
from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
|
9
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
10
|
+
import numpy as np
|
11
|
+
from dataclasses import dataclass
|
12
|
+
from synth_ai.environments.examples.sokoban.taskset import (
|
13
|
+
SokobanTaskInstance,
|
14
|
+
) # Assuming this is where SokobanTaskInstance is defined
|
15
|
+
from synth_ai.environments.reproducibility.core import IReproducibleEngine # Added import
|
16
|
+
import logging
|
17
|
+
from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent
|
18
|
+
from synth_ai.environments.examples.sokoban.engine_helpers.vendored.envs.sokoban_env import (
|
19
|
+
ACTION_LOOKUP,
|
20
|
+
SokobanEnv as GymSokobanEnv,
|
21
|
+
)
|
22
|
+
import numpy as np
|
23
|
+
|
24
|
+
# No monkey-patch needed - we fixed the vendored code directly

# Configure logging for debugging SokobanEngine steps
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig at import time configures the *root* logger for
# every consumer of this module — confirm this is intentional for a library.
logging.basicConfig(level=logging.DEBUG)
# Suppress verbose PIL debug logs
logging.getLogger("PIL").setLevel(logging.WARNING)

# --- Action Mapping ---
# Human-readable action names -> integer codes fed to the vendored
# gym-sokoban env's step(): 0 = no-op, 1-4 = push, 5-8 = move.
ACTION_STRING_TO_INT: Dict[str, int] = {
    "no operation": 0,
    "push up": 1,
    "push down": 2,
    "push left": 3,
    "push right": 4,
    "move up": 5,
    "move down": 6,
    "move left": 7,
    "move right": 8,
}
# Inverse lookup: integer action code -> human-readable name.
INT_TO_ACTION_STRING: Dict[int, str] = {v: k for k, v in ACTION_STRING_TO_INT.items()}
|
45
|
+
|
46
|
+
|
47
|
+
@dataclass
class SokobanEngineSnapshot(StatefulEngineSnapshot):
    """Serializable bundle pairing a task instance with the wrapped env state.

    Produced by ``SokobanEngine._serialize_engine`` and consumed by
    ``SokobanEngine._deserialize_engine``.
    """

    # Output of TaskInstance.serialize(); restored via SokobanTaskInstance.deserialize.
    task_instance_dict: Dict
    # JSON-ready dict of the wrapped gym-sokoban env (grids, counters, box mapping).
    engine_snapshot: Dict
|
51
|
+
|
52
|
+
|
53
|
+
@dataclass
class SokobanPublicState:
    """Agent-visible snapshot of the Sokoban board plus step bookkeeping.

    The grids are numpy arrays copied out of the wrapped env, so mutating
    them never touches the live environment.
    """

    dim_room: Tuple[int, int]
    room_fixed: np.ndarray  # numpy kinda sucks
    room_state: np.ndarray
    player_position: Tuple[int, int]
    boxes_on_target: int
    num_steps: int
    max_steps: int
    last_action_name: str
    num_boxes: int
    error_info: Optional[str] = None

    def diff(self, prev_state: "SokobanPublicState") -> Dict[str, Any]:
        """Compare against *prev_state* field by field.

        Array fields that differ are flagged with ``True``; any other changed
        field maps to an ``(old, new)`` pair.
        """
        delta: Dict[str, Any] = {}
        for name in self.__dataclass_fields__:  # type: ignore[attr-defined]
            current = getattr(self, name)
            previous = getattr(prev_state, name)
            if isinstance(current, np.ndarray):
                if not np.array_equal(current, previous):
                    delta[name] = True
                continue
            if current != previous:
                delta[name] = (previous, current)
        return delta

    @property
    def room_text(self) -> str:
        """ASCII visualization of the room state"""
        return _grid_to_text(self.room_state)

    def to_dict(self) -> Dict[str, Any]:
        """JSON-friendly view; numpy grids become nested Python lists."""
        payload: Dict[str, Any] = {}
        payload["dim_room"] = self.dim_room
        payload["room_fixed"] = self.room_fixed.tolist()  # numpy -> list
        payload["room_state"] = self.room_state.tolist()  # numpy -> list
        payload["player_position"] = self.player_position
        payload["boxes_on_target"] = self.boxes_on_target
        payload["num_steps"] = self.num_steps
        payload["max_steps"] = self.max_steps
        payload["last_action_name"] = self.last_action_name
        payload["num_boxes"] = self.num_boxes
        payload["error_info"] = self.error_info
        return payload

    def __repr__(self) -> str:
        """Compact summary; deliberately omits the grids to avoid numpy recursion."""
        return f"SokobanPublicState(dim_room={self.dim_room}, num_steps={self.num_steps}, boxes_on_target={self.boxes_on_target})"

    def __str__(self) -> str:
        """Same compact summary as __repr__."""
        return repr(self)
|
104
|
+
|
105
|
+
|
106
|
+
@dataclass
|
107
|
+
class SokobanPrivateState:
|
108
|
+
reward_last: float
|
109
|
+
total_reward: float
|
110
|
+
terminated: bool
|
111
|
+
truncated: bool
|
112
|
+
rng_state: dict | None = None
|
113
|
+
|
114
|
+
def diff(self, prev_state: "SokobanPrivateState") -> Dict[str, Any]:
|
115
|
+
changes: Dict[str, Any] = {}
|
116
|
+
for field in self.__dataclass_fields__: # type: ignore[attr-defined]
|
117
|
+
new_v, old_v = getattr(self, field), getattr(prev_state, field)
|
118
|
+
if new_v != old_v:
|
119
|
+
changes[field] = (old_v, new_v)
|
120
|
+
return changes
|
121
|
+
|
122
|
+
def to_dict(self) -> Dict[str, Any]:
|
123
|
+
"""Convert to dictionary with proper serialization."""
|
124
|
+
return {
|
125
|
+
"reward_last": self.reward_last,
|
126
|
+
"total_reward": self.total_reward,
|
127
|
+
"terminated": self.terminated,
|
128
|
+
"truncated": self.truncated,
|
129
|
+
"rng_state": self.rng_state,
|
130
|
+
}
|
131
|
+
|
132
|
+
def __repr__(self) -> str:
|
133
|
+
"""Safe string representation."""
|
134
|
+
return f"SokobanPrivateState(reward_last={self.reward_last}, total_reward={self.total_reward}, terminated={self.terminated})"
|
135
|
+
|
136
|
+
def __str__(self) -> str:
|
137
|
+
"""Safe string representation."""
|
138
|
+
return self.__repr__()
|
139
|
+
|
140
|
+
|
141
|
+
# Note - just how we roll! Show your agent whatever state you want
|
142
|
+
# Close to original
|
143
|
+
def _grid_to_text(grid: np.ndarray) -> str:
    """Pretty 3-char glyphs for each cell – same lookup the legacy renderer used."""
    rendered_rows = []
    for row in grid:
        glyphs = (GRID_LOOKUP.get(int(cell), "?") for cell in row)  # type: ignore[arg-type]
        rendered_rows.append("".join(glyphs))
    return "\n".join(rendered_rows)
|
149
|
+
|
150
|
+
|
151
|
+
class SynthSokobanObservationCallable(GetObservationCallable):
    """Default per-step observation: ASCII board plus scalar counters."""

    def __init__(self):
        pass

    async def get_observation(
        self, pub: SokobanPublicState, priv: SokobanPrivateState
    ) -> InternalObservation:  # type: ignore[override]
        observation: Dict[str, Any] = {}
        observation["room_text"] = _grid_to_text(pub.room_state)
        observation["player_position"] = tuple(map(int, pub.player_position))
        observation["boxes_on_target"] = int(pub.boxes_on_target)
        observation["steps_taken"] = int(pub.num_steps)
        observation["max_steps"] = int(pub.max_steps)
        observation["last_action"] = pub.last_action_name
        observation["reward_last"] = float(priv.reward_last)
        observation["total_reward"] = float(priv.total_reward)
        observation["terminated"] = bool(priv.terminated)
        observation["truncated"] = bool(priv.truncated)
        observation["num_boxes"] = int(pub.num_boxes)
        return observation
|
172
|
+
|
173
|
+
|
174
|
+
# Close to original
|
175
|
+
class SynthSokobanCheckpointObservationCallable(GetObservationCallable):
    """
    Snapshot emitted once after the episode finishes.
    Mirrors the legacy 'final_observation' concept: full board + final tallies.
    """

    def __init__(self):
        pass

    async def get_observation(
        self, pub: SokobanPublicState, priv: SokobanPrivateState
    ) -> InternalObservation:  # type: ignore[override]
        final_board = _grid_to_text(pub.room_state)
        summary: Dict[str, Any] = {}
        summary["room_text_final"] = final_board
        summary["boxes_on_target_final"] = int(pub.boxes_on_target)
        summary["steps_taken_final"] = int(pub.num_steps)
        summary["total_reward"] = float(priv.total_reward)
        summary["terminated"] = bool(priv.terminated)
        summary["truncated"] = bool(priv.truncated)
        summary["num_boxes"] = int(pub.num_boxes)
        return summary
|
197
|
+
|
198
|
+
|
199
|
+
# Think of engine as the actual logic, then with hooks to update the public and private state
|
200
|
+
# Note - I don't really want to split up the transformation/engine logic from the instance information. There's already quite a bit of abstraction, so let's make the hard call here. I observe that this class does combine the responsibility of tracking engine state AND containing dynamics, but I think it's fine.
|
201
|
+
|
202
|
+
|
203
|
+
# Cell-value -> 3-char glyph used by _grid_to_text. Per the level encoding used
# elsewhere in this file: 0=wall, 1=floor, 2=target, 3=box-on-target, 4=box,
# 5=player; 6 presumably player-on-target — TODO confirm against vendored env.
GRID_LOOKUP = {0: " # ", 1: " _ ", 2: " O ", 3: " √ ", 4: " X ", 5: " P ", 6: " S "}
|
204
|
+
|
205
|
+
|
206
|
+
def _count_boxes_on_target(room_state: np.ndarray) -> int:
|
207
|
+
"""Return number of boxes currently sitting on target tiles."""
|
208
|
+
return int(np.sum(room_state == 3))
|
209
|
+
|
210
|
+
|
211
|
+
def package_sokoban_env_from_engine_snapshot(
    engine_snapshot: Dict[str, Any],
) -> GymSokobanEnv:
    """Instantiate SokobanEnv and load every field from a saved-state dict.

    Required keys: ``dim_room``, ``room_fixed``, ``room_state``. Optional
    (with defaults): ``max_steps`` (120), ``num_boxes`` (1), ``box_mapping``,
    ``num_env_steps``, ``boxes_on_target``, ``reward_last``,
    ``np_random_state``. ``box_mapping`` may be either a list of
    ``{"original": [...], "current": [...]}`` entries or a dict keyed by
    stringified coordinates like ``"[r,c]"``.
    """
    # 1. create empty env (skip reset)
    env = GymSokobanEnv(
        dim_room=tuple(engine_snapshot["dim_room"]),
        max_steps=engine_snapshot.get("max_steps", 120),
        num_boxes=engine_snapshot.get("num_boxes", 1),
        reset=False,
    )

    # 2. restore core grids
    env.room_fixed = np.asarray(engine_snapshot["room_fixed"], dtype=int)
    env.room_state = np.asarray(engine_snapshot["room_state"], dtype=int)

    # 3. restore auxiliary data
    raw_map = engine_snapshot.get("box_mapping", {})
    if isinstance(raw_map, list):  # list-of-dict form
        env.box_mapping = {tuple(e["original"]): tuple(e["current"]) for e in raw_map}
    else:  # string-keyed dict form
        env.box_mapping = {
            tuple(map(int, k.strip("[]").split(","))): tuple(v) for k, v in raw_map.items()
        }

    # Player is the cell with value 5; assumes exactly one such cell exists in
    # room_state — raises IndexError if absent. TODO confirm upstream guarantees this.
    env.player_position = np.argwhere(env.room_state == 5)[0]
    env.num_env_steps = engine_snapshot.get("num_env_steps", 0)
    # Default boxes_on_target to a recount of cells equal to 3 (box-on-target).
    env.boxes_on_target = engine_snapshot.get("boxes_on_target", int(np.sum(env.room_state == 3)))
    env.reward_last = engine_snapshot.get("reward_last", 0)

    # 4. restore RNG (if provided)
    # NOTE(review): this writes the legacy RandomState tuple format
    # (key/state/pos), while _reset_engine reads np_random.bit_generator.state
    # (new Generator API) — confirm both paths target the same RNG flavor.
    rng = engine_snapshot.get("np_random_state")
    if rng:
        env.seed()  # init env.np_random
        env.np_random.set_state(
            (
                rng["key"],
                np.array(rng["state"], dtype=np.uint32),
                rng["pos"],
                0,  # has_gauss
                0.0,  # cached_gaussian
            )
        )

    return env
|
256
|
+
|
257
|
+
|
258
|
+
# --- Reward Components ---
|
259
|
+
class SokobanGoalAchievedComponent(RewardComponent):
    """Emits 1.0 exactly when every box sits on a target, 0.0 otherwise."""

    async def score(self, state: "SokobanPublicState", action: Any) -> float:
        solved = state.boxes_on_target == state.num_boxes
        return 1.0 if solved else 0.0
|
264
|
+
|
265
|
+
|
266
|
+
class SokobanStepPenaltyComponent(RewardComponent):
    """Constant per-step penalty that discourages aimless wandering."""

    def __init__(self, penalty: float = -0.01):
        super().__init__()
        # Flat amount added to the reward every step (negative by default).
        self.penalty = penalty
        self.weight = 1.0

    async def score(self, state: Any, action: Any) -> float:
        # Unconditional: the penalty applies regardless of state or action.
        return self.penalty
|
274
|
+
|
275
|
+
|
276
|
+
class SokobanEngine(StatefulEngine, IReproducibleEngine):
|
277
|
+
task_instance: TaskInstance
|
278
|
+
package_sokoban_env: GymSokobanEnv
|
279
|
+
|
280
|
+
# sokoban stuff
|
281
|
+
|
282
|
+
def __init__(self, task_instance: TaskInstance):
|
283
|
+
self.task_instance = task_instance
|
284
|
+
self._total_reward = 0.0 # Initialize total_reward
|
285
|
+
self._current_action_for_reward: Optional[int] = None
|
286
|
+
self.reward_stack = RewardStack(
|
287
|
+
components=[
|
288
|
+
SokobanGoalAchievedComponent(),
|
289
|
+
SokobanStepPenaltyComponent(penalty=-0.01),
|
290
|
+
]
|
291
|
+
)
|
292
|
+
|
293
|
+
init_snap: dict | None = getattr(self.task_instance, "initial_engine_snapshot", None)
|
294
|
+
|
295
|
+
if init_snap:
|
296
|
+
# Initialize package_sokoban_env here using the snapshot
|
297
|
+
self.package_sokoban_env = package_sokoban_env_from_engine_snapshot(init_snap)
|
298
|
+
# Ensure counters are consistent with the snapshot state
|
299
|
+
self.package_sokoban_env.boxes_on_target = _count_boxes_on_target(
|
300
|
+
self.package_sokoban_env.room_state
|
301
|
+
)
|
302
|
+
else:
|
303
|
+
# No initial snapshot - this should not happen with the new pre-generated puzzle system
|
304
|
+
# Create a minimal default environment as fallback
|
305
|
+
logger.warning(
|
306
|
+
"No initial_engine_snapshot provided - this should not happen with verified puzzles"
|
307
|
+
)
|
308
|
+
self.package_sokoban_env = GymSokobanEnv(
|
309
|
+
dim_room=(5, 5),
|
310
|
+
max_steps=50,
|
311
|
+
num_boxes=1,
|
312
|
+
reset=False, # Don't reset during creation to avoid generation
|
313
|
+
)
|
314
|
+
|
315
|
+
# gives the observation!
|
316
|
+
# also final rewards when those are passed in
|
317
|
+
async def _render(
|
318
|
+
self,
|
319
|
+
private_state: SokobanPrivateState,
|
320
|
+
public_state: SokobanPublicState,
|
321
|
+
get_observation: Optional[GetObservationCallable] = None,
|
322
|
+
) -> str:
|
323
|
+
"""
|
324
|
+
1. choose the observation callable (default = SynthSokobanObservationCallable)
|
325
|
+
2. fetch obs via callable(pub, priv)
|
326
|
+
3. if callable returned a dict -> pretty-print board + footer
|
327
|
+
if str -> forward unchanged
|
328
|
+
"""
|
329
|
+
# 1 – pick callable
|
330
|
+
obs_cb = get_observation or SynthSokobanObservationCallable()
|
331
|
+
|
332
|
+
# 2 – pull observation
|
333
|
+
obs = await obs_cb.get_observation(public_state, private_state)
|
334
|
+
|
335
|
+
# 3 – stringify
|
336
|
+
if isinstance(obs, str):
|
337
|
+
return obs
|
338
|
+
|
339
|
+
if isinstance(obs, dict):
|
340
|
+
board_txt = (
|
341
|
+
obs.get("room_text")
|
342
|
+
or obs.get("room_text_final")
|
343
|
+
or _grid_to_text(public_state.room_state)
|
344
|
+
)
|
345
|
+
footer = (
|
346
|
+
f"steps: {public_state.num_steps}/{public_state.max_steps} | "
|
347
|
+
f"boxes✓: {public_state.boxes_on_target} | "
|
348
|
+
f"last_r: {private_state.reward_last:.2f} | "
|
349
|
+
f"total_r: {private_state.total_reward:.2f}"
|
350
|
+
)
|
351
|
+
return f"{board_txt}\n{footer}"
|
352
|
+
|
353
|
+
# unknown payload type -> fallback
|
354
|
+
return str(obs)
|
355
|
+
|
356
|
+
# yields private state, public state
|
357
|
+
async def _step_engine(self, action: int) -> Tuple[SokobanPrivateState, SokobanPublicState]:
|
358
|
+
self._current_action_for_reward = action # Set context for reward
|
359
|
+
|
360
|
+
# --- Run underlying package environment step ---
|
361
|
+
# The raw reward from package_sokoban_env.step() will be ignored,
|
362
|
+
# as we are now using our RewardStack for a more structured reward calculation.
|
363
|
+
obs_raw, _, terminated_gym, info = self.package_sokoban_env.step(action)
|
364
|
+
|
365
|
+
self.package_sokoban_env.boxes_on_target = _count_boxes_on_target(
|
366
|
+
self.package_sokoban_env.room_state
|
367
|
+
)
|
368
|
+
current_pub_state = SokobanPublicState(
|
369
|
+
dim_room=self.package_sokoban_env.dim_room,
|
370
|
+
room_fixed=self.package_sokoban_env.room_fixed.copy(),
|
371
|
+
room_state=self.package_sokoban_env.room_state.copy(),
|
372
|
+
player_position=tuple(self.package_sokoban_env.player_position),
|
373
|
+
boxes_on_target=self.package_sokoban_env.boxes_on_target,
|
374
|
+
num_steps=self.package_sokoban_env.num_env_steps,
|
375
|
+
max_steps=self.package_sokoban_env.max_steps,
|
376
|
+
last_action_name=ACTION_LOOKUP.get(action, "Unknown"),
|
377
|
+
num_boxes=self.package_sokoban_env.num_boxes,
|
378
|
+
)
|
379
|
+
|
380
|
+
# --- Calculate reward using RewardStack ---
|
381
|
+
# The 'state' for reward components is current_pub_state.
|
382
|
+
# The 'action' for reward components is the raw agent action.
|
383
|
+
reward_from_stack = await self.reward_stack.step_reward(
|
384
|
+
state=current_pub_state, action=self._current_action_for_reward
|
385
|
+
)
|
386
|
+
self._current_action_for_reward = None # Reset context
|
387
|
+
|
388
|
+
self._total_reward += reward_from_stack
|
389
|
+
# Update reward_last on the package_sokoban_env if it's used by its internal logic or for direct inspection.
|
390
|
+
# However, the authoritative reward for our framework is reward_from_stack.
|
391
|
+
self.package_sokoban_env.reward_last = reward_from_stack
|
392
|
+
|
393
|
+
# --- Determine terminated and truncated status based on gym env and game logic ---
|
394
|
+
solved = self.package_sokoban_env.boxes_on_target == self.package_sokoban_env.num_boxes
|
395
|
+
terminated = terminated_gym or solved # terminated_gym from underlying env, or solved state
|
396
|
+
# If underlying env says terminated due to max_steps, it is truncation for us.
|
397
|
+
# If solved, it's termination. Otherwise, depends on max_steps.
|
398
|
+
truncated = (
|
399
|
+
self.package_sokoban_env.num_env_steps >= self.package_sokoban_env.max_steps
|
400
|
+
) and not solved
|
401
|
+
if solved:
|
402
|
+
terminated = True # Ensure solved always terminates
|
403
|
+
truncated = False # Cannot be truncated if solved
|
404
|
+
|
405
|
+
priv = SokobanPrivateState(
|
406
|
+
reward_last=reward_from_stack,
|
407
|
+
total_reward=self._total_reward,
|
408
|
+
terminated=terminated,
|
409
|
+
truncated=truncated,
|
410
|
+
)
|
411
|
+
return priv, current_pub_state
|
412
|
+
|
413
|
+
async def _reset_engine(
|
414
|
+
self, *, seed: int | None = None
|
415
|
+
) -> Tuple[SokobanPrivateState, SokobanPublicState]:
|
416
|
+
"""
|
417
|
+
(Re)build the wrapped PackageSokobanEnv in a fresh state.
|
418
|
+
|
419
|
+
1. Decide whether we have an initial snapshot in the TaskInstance.
|
420
|
+
2. If yes → hydrate env from it; otherwise call env.reset(seed).
|
421
|
+
3. Zero-out cumulative reward and emit fresh state objects.
|
422
|
+
"""
|
423
|
+
self._total_reward = 0.0
|
424
|
+
self._current_action_for_reward = None
|
425
|
+
|
426
|
+
init_snap: dict | None = getattr(self.task_instance, "initial_engine_snapshot")
|
427
|
+
|
428
|
+
if init_snap:
|
429
|
+
self.package_sokoban_env = package_sokoban_env_from_engine_snapshot(init_snap)
|
430
|
+
# ensure counter correct even if snapshot was stale
|
431
|
+
self.package_sokoban_env.boxes_on_target = _count_boxes_on_target(
|
432
|
+
self.package_sokoban_env.room_state
|
433
|
+
)
|
434
|
+
else:
|
435
|
+
# No initial snapshot - this should not happen with the new pre-generated puzzle system
|
436
|
+
logger.warning(
|
437
|
+
"No initial_engine_snapshot provided during reset - this should not happen with verified puzzles"
|
438
|
+
)
|
439
|
+
# Simple fallback: try to reset the existing environment
|
440
|
+
try:
|
441
|
+
_ = self.package_sokoban_env.reset(seed=seed)
|
442
|
+
# Update the boxes_on_target counter
|
443
|
+
self.package_sokoban_env.boxes_on_target = _count_boxes_on_target(
|
444
|
+
self.package_sokoban_env.room_state
|
445
|
+
)
|
446
|
+
except Exception as e:
|
447
|
+
logger.error(f"Failed to reset environment: {e}")
|
448
|
+
raise RuntimeError(
|
449
|
+
"Environment reset failed. This should not happen with verified puzzles. "
|
450
|
+
"Ensure task instances have initial_engine_snapshot."
|
451
|
+
) from e
|
452
|
+
|
453
|
+
# build first public/private views
|
454
|
+
priv = SokobanPrivateState(
|
455
|
+
reward_last=0.0,
|
456
|
+
total_reward=0.0,
|
457
|
+
terminated=False,
|
458
|
+
truncated=False,
|
459
|
+
rng_state=self.package_sokoban_env.np_random.bit_generator.state,
|
460
|
+
)
|
461
|
+
pub = SokobanPublicState(
|
462
|
+
dim_room=self.package_sokoban_env.dim_room,
|
463
|
+
room_fixed=self.package_sokoban_env.room_fixed.copy(),
|
464
|
+
room_state=self.package_sokoban_env.room_state.copy(),
|
465
|
+
player_position=tuple(self.package_sokoban_env.player_position),
|
466
|
+
boxes_on_target=self.package_sokoban_env.boxes_on_target,
|
467
|
+
num_steps=self.package_sokoban_env.num_env_steps,
|
468
|
+
max_steps=self.package_sokoban_env.max_steps,
|
469
|
+
last_action_name="Initial",
|
470
|
+
num_boxes=self.package_sokoban_env.num_boxes,
|
471
|
+
)
|
472
|
+
return priv, pub
|
473
|
+
|
474
|
+
async def _serialize_engine(self) -> SokobanEngineSnapshot:
|
475
|
+
"""Dump wrapped env + task_instance into a JSON-ready snapshot."""
|
476
|
+
env = self.package_sokoban_env
|
477
|
+
|
478
|
+
# helper – numpy RNG → dict
|
479
|
+
def _rng_state(e):
|
480
|
+
state = e.np_random.bit_generator.state
|
481
|
+
state["state"] = state["state"].tolist()
|
482
|
+
return state
|
483
|
+
|
484
|
+
snap: Dict[str, Any] = {
|
485
|
+
"dim_room": list(env.dim_room),
|
486
|
+
"max_steps": env.max_steps,
|
487
|
+
"num_boxes": env.num_boxes,
|
488
|
+
"room_fixed": env.room_fixed.tolist(),
|
489
|
+
"room_state": env.room_state.tolist(),
|
490
|
+
"box_mapping": [
|
491
|
+
{"original": list(k), "current": list(v)} for k, v in env.box_mapping.items()
|
492
|
+
],
|
493
|
+
"player_position": env.player_position.tolist(),
|
494
|
+
"num_env_steps": env.num_env_steps,
|
495
|
+
"boxes_on_target": env.boxes_on_target,
|
496
|
+
"reward_last": env.reward_last,
|
497
|
+
"total_reward": getattr(self, "_total_reward", 0.0),
|
498
|
+
# "np_random_state": _rng_state(env), # Assuming _rng_state is defined if needed
|
499
|
+
}
|
500
|
+
|
501
|
+
# Serialize the TaskInstance using its own serialize method
|
502
|
+
task_instance_dict = await self.task_instance.serialize()
|
503
|
+
|
504
|
+
return SokobanEngineSnapshot(
|
505
|
+
task_instance_dict=task_instance_dict, # Store serialized TaskInstance
|
506
|
+
engine_snapshot=snap,
|
507
|
+
)
|
508
|
+
|
509
|
+
@classmethod
|
510
|
+
async def _deserialize_engine(
|
511
|
+
cls, sokoban_engine_snapshot: "SokobanEngineSnapshot"
|
512
|
+
) -> "SokobanEngine":
|
513
|
+
"""
|
514
|
+
Recreate a SokobanEngine (including wrapped env and TaskInstance) from a snapshot blob.
|
515
|
+
"""
|
516
|
+
# --- 1. rebuild TaskInstance ----------------------------------- #
|
517
|
+
# Use the concrete SokobanTaskInstance.deserialize method
|
518
|
+
instance = await SokobanTaskInstance.deserialize(sokoban_engine_snapshot.task_instance_dict)
|
519
|
+
|
520
|
+
# --- 2. create engine shell ------------------------------------ #
|
521
|
+
engine = cls.__new__(cls) # bypass __init__
|
522
|
+
StatefulEngine.__init__(engine) # initialise mix-in parts
|
523
|
+
engine.task_instance = instance # assign restored TaskInstance
|
524
|
+
|
525
|
+
# --- 3. initialize attributes that are normally set in __init__ --- #
|
526
|
+
engine._current_action_for_reward = None
|
527
|
+
engine.reward_stack = RewardStack(
|
528
|
+
components=[
|
529
|
+
SokobanGoalAchievedComponent(),
|
530
|
+
SokobanStepPenaltyComponent(penalty=-0.01),
|
531
|
+
]
|
532
|
+
)
|
533
|
+
|
534
|
+
# --- 4. hydrate env & counters --------------------------------- #
|
535
|
+
engine.package_sokoban_env = package_sokoban_env_from_engine_snapshot(
|
536
|
+
sokoban_engine_snapshot.engine_snapshot
|
537
|
+
)
|
538
|
+
engine._total_reward = sokoban_engine_snapshot.engine_snapshot.get("total_reward", 0.0)
|
539
|
+
return engine
|
540
|
+
|
541
|
+
def get_current_states_for_observation(
    self,
) -> Tuple[SokobanPrivateState, SokobanPublicState]:
    """Snapshot the current private/public state WITHOUT advancing the engine.

    Used e.g. by Environment.step to report state when an action fails.
    Arrays are copied so callers cannot mutate the live env grids.
    """
    env = self.package_sokoban_env  # hoist the repeated attribute chain

    # Episode ends when every box sits on a target; it is truncated once the
    # step budget is exhausted.
    terminated = bool(env.boxes_on_target == env.num_boxes)
    truncated = bool(env.num_env_steps >= env.max_steps)

    private_state = SokobanPrivateState(
        reward_last=env.reward_last,  # last known reward
        total_reward=self._total_reward,
        terminated=terminated,
        truncated=truncated,
    )
    public_state = SokobanPublicState(
        dim_room=env.dim_room,
        room_fixed=env.room_fixed.copy(),
        room_state=env.room_state.copy(),
        player_position=tuple(env.player_position),
        boxes_on_target=env.boxes_on_target,
        num_steps=env.num_env_steps,
        max_steps=env.max_steps,
        # "Initial" is reported before any action has been taken.
        last_action_name=ACTION_LOOKUP.get(
            getattr(env, "last_action", -1), "Initial"
        ),
        num_boxes=env.num_boxes,
    )
    return private_state, public_state
|
571
|
+
|
572
|
+
|
573
|
+
if __name__ == "__main__":
    # Manual smoke test: build an engine, reset it, and play 10 random moves,
    # printing the board and state objects after each step.
    # // 0=wall, 1=floor, 2=target
    # // 4=box-not-on-target, 5=player
    # initial_room = {
    #     "dim_room": [5, 5],
    #     "max_steps": 120,
    #     "num_boxes": 1,
    #     "seed": 42,
    #     "room_fixed": [
    #         [0, 0, 0, 0, 0],
    #         [0, 1, 1, 2, 0],
    #         [0, 1, 0, 1, 0],
    #         [0, 1, 5, 1, 0],
    #         [0, 0, 0, 0, 0]
    #     ],
    #     "room_state": [
    #         [0, 0, 0, 0, 0],
    #         [0, 1, 4, 1, 0],
    #         [0, 1, 0, 1, 0],
    #         [0, 1, 5, 1, 0],
    #         [0, 0, 0, 0, 0]
    #     ]
    # }
    # NOTE(review): task_instance_dict is built here but never consumed by
    # sanity() below — presumably example data for the commented-out
    # SokobanEngine.deserialize call at the bottom; confirm intent.
    task_instance_dict = {
        "initial_engine_snapshot": {
            "dim_room": [5, 5],
            "max_steps": 120,
            "num_boxes": 1,
            "room_fixed": [
                [0, 0, 0, 0, 0],
                [0, 1, 1, 2, 0],
                [0, 1, 0, 1, 0],
                [0, 1, 1, 1, 0],
                [0, 0, 0, 0, 0],
            ],
            "room_state": [
                [0, 0, 0, 0, 0],
                [0, 1, 4, 1, 0],
                [0, 1, 0, 1, 0],
                [0, 1, 5, 1, 0],
                [0, 0, 0, 0, 0],
            ],
            "box_mapping": [{"original": [1, 2], "current": [3, 2]}],
            "boxes_on_target": 0,
            # MT19937 RNG state for bit-exact replay; "state" here is a
            # truncated example, not a full 624-word Mersenne Twister vector.
            "np_random_state": {
                "key": "MT19937",
                "state": [1804289383, 846930886, 1681692777, 1714636915],
                "pos": 0,
            },
            "reward_last": 0,
            "num_env_steps": 0,
        }
    }
    import random
    import asyncio

    async def sanity() -> None:
        """Reset the engine and play up to 10 random actions, printing each step."""
        # NOTE(review): TaskInstance() looks like a bare/abstract base here —
        # a concrete SokobanTaskInstance is likely intended; confirm this
        # script still runs.
        task_instance = TaskInstance()
        engine = SokobanEngine(task_instance=task_instance)
        priv, pub = await engine._reset_engine()
        print(await engine._render(priv, pub))  # initial board

        for _ in range(10):  # play 10 random moves
            a = random.randint(0, 8)  # action range 0-8
            priv, pub = await engine._step_engine(a)
            print(f"\n### step {pub.num_steps} — {ACTION_LOOKUP[a]} ###")
            print("public:", pub)
            print("private:", priv)
            print(await engine._render(priv, pub))
            # Stop early if the puzzle is solved or the step budget is hit.
            if priv.terminated or priv.truncated:
                break

    asyncio.run(sanity())
    # sokoban_engine = SokobanEngine.deserialize(
    #     engine_snapshot=SokobanEngineSnapshot(
    #         instance=instance_information,
    #         snapshot_dict=instance_information["initial_engine_snapshot"],
    #     )
    # )

    # Reference shape of a full engine snapshot blob:
    # {
    #     "dim_room": [5, 5],
    #     "max_steps": 120,
    #     "num_boxes": 1,

    #     "room_fixed": [...],  // as above
    #     "room_state": [...],  // current grid (3 = box-on-target)

    #     "box_mapping": {
    #         "[1,3]": [3,2]  // origin-target → current-pos pairs
    #     },
    #     "player_position": [3, 2],  // row, col

    #     "num_env_steps": 15,  // steps already taken
    #     "boxes_on_target": 0,  // live counter

    #     "np_random_state": {  // optional but makes replay bit-exact
    #         "key": "MT19937",
    #         "state": [1804289383, 846930886, ...],
    #         "pos": 123
    #     }
    # }
|
@@ -0,0 +1 @@
|
|
1
|
+
# Engine helpers for Sokoban
|