synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +579 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/wordle/__init__.py +29 -0
- synth_ai/environments/examples/wordle/engine.py +391 -0
- synth_ai/environments/examples/wordle/environment.py +154 -0
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
- synth_ai/environments/examples/wordle/taskset.py +222 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/environments/service/core_routes.py +38 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
- synth_ai/learning/prompts/mipro.py +273 -1
- synth_ai/learning/prompts/random_search.py +247 -0
- synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
- synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
- synth_ai/lm/injection.py +81 -0
- synth_ai/lm/overrides.py +204 -0
- synth_ai/lm/provider_support/anthropic.py +39 -12
- synth_ai/lm/provider_support/openai.py +31 -4
- synth_ai/lm/vendors/core/anthropic_api.py +16 -0
- synth_ai/lm/vendors/openai_standard.py +35 -5
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,228 @@
|
|
1
|
+
from typing import List, Optional, Any, Dict, Union
|
2
|
+
from pydantic import BaseModel
|
3
|
+
import dataclasses
|
4
|
+
|
5
|
+
from synth_ai.environments.examples.sokoban.engine import (
|
6
|
+
SokobanEngine,
|
7
|
+
SynthSokobanObservationCallable,
|
8
|
+
SokobanPrivateState,
|
9
|
+
SokobanPublicState,
|
10
|
+
SynthSokobanCheckpointObservationCallable,
|
11
|
+
SokobanEngineSnapshot,
|
12
|
+
)
|
13
|
+
from synth_ai.environments.environment.shared_engine import (
|
14
|
+
GetObservationCallable,
|
15
|
+
InternalObservation,
|
16
|
+
)
|
17
|
+
from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
|
18
|
+
from synth_ai.environments.stateful.core import StatefulEnvironment
|
19
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
20
|
+
from synth_ai.environments.environment.tools import (
|
21
|
+
AbstractTool,
|
22
|
+
EnvToolCall,
|
23
|
+
ToolResult,
|
24
|
+
TOOL_REGISTRY,
|
25
|
+
register_tool,
|
26
|
+
)
|
27
|
+
|
28
|
+
|
29
|
+
# --- Tool Definition ---
|
30
|
+
class SokobanActionInput(BaseModel):
    # Argument schema for the "interact" tool. SokobanInteractTool validates
    # call.args against this model and forwards `action` unchanged to
    # SokobanEngine._step_engine.
    # NOTE(review): the meaning/valid range of action codes is defined by the
    # engine, not here — confirm against SokobanEngine before documenting it.
    action: int
|
32
|
+
|
33
|
+
|
34
|
+
class SokobanInteractTool(AbstractTool):
    """Tool that applies one agent action to a bound ``SokobanEngine``.

    On success the result payload carries the public and private state
    dictionaries produced by the step. On any failure the result is marked
    ``ok=False``, ``error`` holds the exception text, and the payload contains
    only the engine's current public state so the caller keeps some context.
    """

    name = "interact"
    description = "Performs an action (e.g., move) in the Sokoban environment."
    call_schema = SokobanActionInput
    result_schema = ToolResult

    def __init__(self, engine: SokobanEngine):
        # Every call made through this tool instance steps this engine.
        self.engine = engine

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        try:
            parsed = self.call_schema(**call.args)
            private_state, public_state = await self.engine._step_engine(parsed.action)
            success_payload = {
                "public": public_state.to_dict(),
                "private": private_state.to_dict(),
            }
            return ToolResult(ok=True, payload=success_payload)
        except Exception as exc:
            # Report the failure together with whatever public state the engine holds now.
            _, current_public = self.engine.get_current_states_for_observation()
            failure_payload = {"public": current_public.to_dict()}
            return ToolResult(ok=False, error=str(exc), payload=failure_payload)
|
62
|
+
|
63
|
+
|
64
|
+
class SokobanEnvironment(StatefulEnvironment, ReproducibleEnvironment[SokobanEngine]):
    """Sokoban environment driven through a single ``interact`` tool.

    Tool calls are normalised by :meth:`validate_tool_calls`, executed by a
    :class:`SokobanInteractTool` bound to ``self.engine``, and converted into
    observations with the configured step/checkpoint observation callables.
    """

    def __init__(
        self,
        task_instance: TaskInstance,
        custom_step_obs: Optional[GetObservationCallable] = None,
        custom_ckpt_obs: Optional[GetObservationCallable] = None,
    ):
        """Create the engine for ``task_instance`` and the tool that drives it.

        Args:
            task_instance: Task the engine is constructed from.
            custom_step_obs: Observation callable used by initialize/step/terminate;
                defaults to ``SynthSokobanObservationCallable()``.
            custom_ckpt_obs: Observation callable used by checkpoint; defaults to
                ``SynthSokobanCheckpointObservationCallable()``.
        """
        self.name = "Sokoban"
        self.task_instance = task_instance
        self.custom_step_observation_callable = custom_step_obs or SynthSokobanObservationCallable()
        self.custom_checkpoint_observation_callable = (
            custom_ckpt_obs or SynthSokobanCheckpointObservationCallable()
        )
        self.engine: SokobanEngine = SokobanEngine(task_instance)

        self._interact_tool = SokobanInteractTool(self.engine)
        # Only the first "interact" tool ever created is placed in the shared
        # TOOL_REGISTRY, so that entry stays bound to the first environment's
        # engine. This class never looks the tool up in the registry: ``step``
        # always calls ``self._interact_tool`` directly.
        if self._interact_tool.name not in TOOL_REGISTRY:
            register_tool(self._interact_tool)

    async def initialize(self) -> InternalObservation:
        """Reset the engine and return the initial step observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._to_observation(priv, pub, self.custom_step_observation_callable)

    async def terminate(self) -> InternalObservation:
        """Mark the current private state terminated and return a final observation.

        When the observation is a dict it also receives ``terminated=True`` and a
        termination message.
        """
        priv, pub = self.engine.get_current_states_for_observation()
        priv.terminated = True
        obs_dict = {"terminated": True, "message": "Environment terminated."}
        return await self._to_observation(
            priv, pub, self.custom_step_observation_callable, extra_obs=obs_dict
        )

    @staticmethod
    def _ensure_interact(call: EnvToolCall) -> EnvToolCall:
        """Return ``call`` unchanged if it targets the ``interact`` tool, else raise ValueError."""
        if call.tool != "interact":
            raise ValueError(f"Unknown tool: {call.tool}. Expected 'interact'.")
        return call

    def validate_tool_calls(
        self,
        tool_calls: Union[
            EnvToolCall,
            List[Dict[str, Any]],
            List[List[Dict[str, Any]]],
            Dict[str, Any],
        ],
    ) -> EnvToolCall:
        """Normalise the accepted call shapes into a single ``interact`` EnvToolCall.

        Accepts an ``EnvToolCall``, a dict with ``"tool"``/``"args"`` keys, a list
        of those, or a list of lists of those. Only the first call is used.

        Raises:
            ValueError: on empty lists or when the tool is not ``"interact"``.
            TypeError: on any other input shape.
        """
        raw_call_data: Any
        if isinstance(tool_calls, list):
            if not tool_calls:
                raise ValueError("Received empty list of tool calls.")
            first_item = tool_calls[0]
            if isinstance(first_item, list):
                if not first_item:
                    raise ValueError("Received empty inner list of tool calls.")
                raw_call_data = first_item[0]
            elif isinstance(first_item, (dict, EnvToolCall)):
                raw_call_data = first_item
            else:
                raise TypeError(f"Unexpected type in tool_calls list: {type(first_item)}")
        elif isinstance(tool_calls, dict):
            raw_call_data = tool_calls
        elif isinstance(tool_calls, EnvToolCall):
            raw_call_data = tool_calls
        else:
            raise TypeError(f"Unexpected type for tool_calls: {type(tool_calls)}")

        # An already-built EnvToolCall (top level, in a list, or in an inner list)
        # only needs its tool name checked.
        if isinstance(raw_call_data, EnvToolCall):
            return self._ensure_interact(raw_call_data)

        if not isinstance(raw_call_data, dict):
            raise TypeError(f"Processed call data is not a dict: {type(raw_call_data)}")

        tool_name = raw_call_data.get("tool")
        tool_args = raw_call_data.get("args", {})
        if tool_name != "interact":
            raise ValueError(f"Unknown tool: {tool_name}. Expected 'interact'.")
        return EnvToolCall(tool=tool_name, args=tool_args)

    async def step(
        self,
        tool_calls: Union[
            EnvToolCall,
            List[Dict[str, Any]],
            List[List[Dict[str, Any]]],
            Dict[str, Any],
        ],
    ) -> InternalObservation:
        """Execute one ``interact`` call and return the resulting step observation.

        If the tool fails or returns a payload without both state dicts, the
        engine's live states are observed instead. Any tool error message is
        copied onto the public state's ``error_info`` when that attribute exists.
        """
        agent_call = self.validate_tool_calls(tool_calls)
        tool_result: ToolResult = await self._interact_tool(agent_call)

        payload_dict = tool_result.payload
        priv_dict = pub_dict = None
        if tool_result.ok and isinstance(payload_dict, dict):
            priv_dict = payload_dict.get("private")
            pub_dict = payload_dict.get("public")

        if priv_dict is None or pub_dict is None:
            # Failed call or unexpected payload: fall back to the engine's current states.
            priv_state, pub_state = self.engine.get_current_states_for_observation()
        else:
            priv_state = SokobanPrivateState(**priv_dict)
            pub_state = SokobanPublicState(**pub_dict)

        if tool_result.error and hasattr(pub_state, "error_info"):
            pub_state.error_info = tool_result.error

        return await self._to_observation(
            priv_state, pub_state, self.custom_step_observation_callable
        )

    @staticmethod
    def _snapshot_to_dict(snapshot: Any) -> Any:
        """Convert an engine snapshot into plain data for embedding in an observation.

        NOTE(review): SokobanEngineSnapshot's base type is defined in the engine
        module (not visible here); both pydantic models and dataclasses are handled.
        """
        if hasattr(snapshot, "model_dump"):
            return snapshot.model_dump()
        if dataclasses.is_dataclass(snapshot) and not isinstance(snapshot, type):
            return dataclasses.asdict(snapshot)
        raise TypeError(
            f"Cannot convert engine snapshot of type {type(snapshot).__name__} to a dict"
        )

    async def checkpoint(self) -> InternalObservation:
        """Return a checkpoint observation of the live state.

        When the observation is a dict, the serialized engine snapshot is added
        under ``"engine_snapshot_data"``.
        """
        engine_snapshot: SokobanEngineSnapshot = await self.engine._serialize_engine()
        priv, pub = self.engine.get_current_states_for_observation()
        obs_data = await self._to_observation(
            priv, pub, self.custom_checkpoint_observation_callable
        )
        if isinstance(obs_data, dict):
            obs_data["engine_snapshot_data"] = self._snapshot_to_dict(engine_snapshot)
        return obs_data

    async def _to_observation(
        self,
        priv: SokobanPrivateState,
        pub: SokobanPublicState,
        obs_cb: Optional[GetObservationCallable],
        extra_obs: Optional[Dict[str, Any]] = None,
    ) -> InternalObservation:
        """Build an observation with ``obs_cb`` and merge ``extra_obs`` into dict results.

        Falls back to ``SynthSokobanObservationCallable`` if ``obs_cb`` is None.
        """
        active_obs_cb = obs_cb or SynthSokobanObservationCallable()
        observation = await active_obs_cb.get_observation(pub, priv)
        if extra_obs and isinstance(observation, dict):
            observation.update(extra_obs)
        return observation

    async def _serialize_engine(self) -> SokobanEngineSnapshot:
        """Delegate snapshot creation to the engine."""
        return await self.engine._serialize_engine()

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: SokobanEngineSnapshot, task_instance: TaskInstance
    ) -> "SokobanEnvironment":
        """Rebuild an environment around an engine restored from ``snapshot``.

        ``task_instance`` is both passed to the engine restore and used to
        construct the environment.
        """
        eng = await SokobanEngine._deserialize_engine(snapshot, task_instance)
        env = cls(task_instance)
        env.engine = eng
        # __init__ bound the tool to the engine it created itself; rebind it,
        # otherwise ``step`` would advance that discarded engine instead of ``eng``.
        env._interact_tool = SokobanInteractTool(eng)
        return env
|
@@ -0,0 +1,438 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Generate verified solvable Sokoban puzzles.
|
4
|
+
|
5
|
+
This script creates 100 solvable Sokoban puzzles for each difficulty level in DIFFICULTY_CONFIGS (currently 4 levels, 400 puzzles in total)
|
6
|
+
and saves them as JSON. Each puzzle is verified to be solvable using BFS.
|
7
|
+
"""
|
8
|
+
|
9
|
+
import json
|
10
|
+
import logging
|
11
|
+
import numpy as np
|
12
|
+
from typing import Dict, List, Tuple, Optional, Any, Set
|
13
|
+
from pathlib import Path
|
14
|
+
from dataclasses import dataclass, asdict
|
15
|
+
from synth_ai.environments.examples.sokoban.engine_helpers.room_utils import (
|
16
|
+
generate_room,
|
17
|
+
get_shortest_action_path,
|
18
|
+
)
|
19
|
+
|
20
|
+
# Configure root logging for script use, then get this module's logger.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
23
|
+
|
24
|
+
|
25
|
+
@dataclass()
class SokobanPuzzle:
    """One Sokoban level that passed the solvability check, plus how it was generated."""

    # Identifier such as "easy_007" (difficulty name plus zero-padded index).
    id: str
    # Name of the DIFFICULTY_CONFIGS entry the puzzle was generated for.
    difficulty: str
    num_boxes: int
    # Room dimensions exactly as given in the difficulty config.
    dim_room: Tuple[int, int]
    # Fixed room structure (walls, targets, floors) as nested int lists.
    room_fixed: List[List[int]]
    # Mutable room contents (player, boxes) as nested int lists.
    room_state: List[List[int]]
    # Keys are comma-joined coordinate pairs like "3,4"; values are coordinate lists.
    box_mapping: Dict[str, List[int]]
    # Action codes returned by the solvability check, and their count.
    solution_path: List[int]
    solution_length: int
    # Seed that produced the room, so it can be regenerated.
    generation_seed: int
    # Step budget taken from the difficulty config.
    max_steps: int
|
40
|
+
|
41
|
+
|
42
|
+
# Generation settings per difficulty tier, keyed by tier name. Each tier gives:
#   num_boxes              - boxes placed by the room generator
#   dim_room               - room dimensions passed to the generator
#   max_steps              - step budget stored with the puzzle; also used as the
#                            search depth when verifying solvability
#   target_solution_length - inclusive (min, max) range of accepted solution lengths
#   search_depth           - generator search depth (num_steps is half of it)
DIFFICULTY_CONFIGS = dict(
    ultra_easy=dict(
        num_boxes=1,
        dim_room=(5, 5),
        max_steps=50,
        target_solution_length=(3, 8),
        search_depth=30,
    ),
    easy=dict(
        num_boxes=1,
        dim_room=(6, 6),
        max_steps=80,
        target_solution_length=(8, 15),
        search_depth=50,
    ),
    medium=dict(
        num_boxes=2,
        dim_room=(7, 7),
        max_steps=120,
        target_solution_length=(15, 30),
        search_depth=80,
    ),
    hard=dict(
        num_boxes=3,
        dim_room=(8, 8),
        max_steps=200,
        target_solution_length=(30, 60),
        search_depth=120,
    ),
)
|
73
|
+
|
74
|
+
|
75
|
+
def verify_puzzle_solvable(
    room_fixed: np.ndarray, room_state: np.ndarray, max_depth: int = 200
) -> Optional[List[int]]:
    """Search for a solution with ``get_shortest_action_path`` and return it.

    Args:
        room_fixed: Fixed room structure (walls, targets, floors).
        room_state: Current room contents (player, boxes).
        max_depth: Depth limit forwarded to the search as ``MAX_DEPTH``.

    Returns:
        The action list found by the search, or ``None`` when the search result
        is empty/falsy or the search raises (the error is logged as a warning).
    """
    try:
        path = get_shortest_action_path(room_fixed, room_state, MAX_DEPTH=max_depth)
        found = bool(path)
    except Exception as err:
        logger.warning(f"Error verifying puzzle: {err}")
        return None
    return path if found else None
|
95
|
+
|
96
|
+
|
97
|
+
def setup_instances_directory() -> Path:
    """Return the ``instances`` folder next to this module, creating it if missing."""
    target = Path(__file__).parent.joinpath("instances")
    target.mkdir(exist_ok=True)
    return target
|
102
|
+
|
103
|
+
|
104
|
+
def get_jsonl_path(instances_dir: Path, difficulty: str) -> Path:
    """Return ``<instances_dir>/<difficulty>.jsonl``, the per-difficulty puzzle file."""
    file_name = f"{difficulty}.jsonl"
    return instances_dir.joinpath(file_name)
|
107
|
+
|
108
|
+
|
109
|
+
def save_puzzle_to_jsonl(puzzle: SokobanPuzzle, jsonl_path: Path):
    """Append ``puzzle`` to ``jsonl_path`` as a single JSON line.

    Numpy values inside the record are converted by ``convert_numpy_types``.
    """
    with open(jsonl_path, "a") as handle:
        record = asdict(puzzle)
        handle.write(json.dumps(record, default=convert_numpy_types) + "\n")
|
113
|
+
|
114
|
+
|
115
|
+
def convert_numpy_types(obj):
    """``json`` ``default=`` hook that converts numpy values to native Python types.

    Handles numpy integers, floats, booleans and arrays (via ``tolist``), and
    falls back to ``.item()`` for any other numpy scalar. ``np.bool_`` is not a
    subclass of ``np.integer``, so it needs its own branch.

    Raises:
        TypeError: for non-numpy objects, as ``json`` expects from a ``default`` hook.
    """
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # Any remaining numpy scalar (e.g. np.str_) converts to its Python equivalent.
        return obj.item()
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
124
|
+
|
125
|
+
|
126
|
+
def load_existing_puzzles(jsonl_path: Path) -> Set[str]:
    """Collect the puzzle IDs already stored in ``jsonl_path``.

    Blank lines are ignored. Lines that are not valid JSON objects with an
    ``"id"`` key are skipped instead of aborting the scan, for example a record
    truncated by an interrupted run.

    Returns:
        The set of IDs found; empty when the file does not exist.
    """
    existing_ids: Set[str] = set()
    if not jsonl_path.exists():
        return existing_ids
    with open(jsonl_path, "r") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            try:
                existing_ids.add(json.loads(stripped)["id"])
            except (json.JSONDecodeError, KeyError, TypeError):
                # Malformed JSON, missing "id", or a non-object / unhashable record.
                continue
    return existing_ids
|
138
|
+
|
139
|
+
|
140
|
+
def count_existing_puzzles(jsonl_path: Path) -> int:
    """Return the number of non-blank lines in ``jsonl_path`` (0 if the file is absent)."""
    if jsonl_path.exists():
        with open(jsonl_path, "r") as handle:
            return len([row for row in handle if row.strip()])
    return 0
|
146
|
+
|
147
|
+
|
148
|
+
def _box_mapping_to_json(box_mapping) -> Dict[str, Any]:
    """Turn a generator box mapping into a JSON-friendly dict.

    Tuple keys become ``"a,b"`` strings and tuple values become two-element int
    lists. Any other key is stringified with ``str`` and its value kept as-is.
    """
    converted: Dict[str, Any] = {}
    for raw_key, raw_value in box_mapping.items():
        if not isinstance(raw_key, tuple):
            converted[str(raw_key)] = raw_value
            continue
        key_text = f"{int(raw_key[0])},{int(raw_key[1])}"
        if isinstance(raw_value, tuple):
            converted[key_text] = [int(raw_value[0]), int(raw_value[1])]
        else:
            converted[key_text] = raw_value
    return converted


def generate_puzzle_for_difficulty(
    difficulty: str, config: Dict, seed: int, puzzle_id: str
) -> Optional[SokobanPuzzle]:
    """Try up to 20 seeds to produce a solvable puzzle for one difficulty tier.

    Attempt ``k`` (0-based) uses seed ``seed + k * 1000``. An attempt is rejected
    in three cases: the solvability check finds no path, the path length falls
    outside ``config["target_solution_length"]`` (inclusive), or any step raises.
    Exceptions are logged as warnings.

    Args:
        difficulty: Difficulty tier name recorded on the puzzle.
        config: That tier's entry from ``DIFFICULTY_CONFIGS``.
        seed: Base seed for the attempts.
        puzzle_id: Identifier recorded on the puzzle.

    Returns:
        The first accepted ``SokobanPuzzle``, or ``None`` if every attempt failed.
    """
    max_attempts = 20

    for attempt_idx in range(max_attempts):
        attempt_no = attempt_idx + 1
        attempt_seed = seed + attempt_idx * 1000

        try:
            room_structure, room_state, box_mapping, _ = generate_room(
                dim=config["dim_room"],
                initial_seed=attempt_seed,
                num_boxes=config["num_boxes"],
                search_depth=config["search_depth"],
                num_steps=config["search_depth"] // 2,
            )

            solution = verify_puzzle_solvable(
                room_structure, room_state, max_depth=config["max_steps"]
            )
            if solution is None:
                logger.debug(f"Puzzle {puzzle_id} attempt {attempt_no} not solvable")
                continue

            length = len(solution)
            target_min, target_max = config["target_solution_length"]
            if length < target_min or length > target_max:
                logger.debug(
                    f"Puzzle {puzzle_id} attempt {attempt_no} solution length {length} not in range {target_min}-{target_max}"
                )
                continue

            # Plain Python containers so the record serializes to JSON directly.
            fixed_grid = room_structure.tolist()
            state_grid = room_state.tolist()
            mapping_json = _box_mapping_to_json(box_mapping)

            puzzle = SokobanPuzzle(
                id=puzzle_id,
                difficulty=difficulty,
                num_boxes=int(config["num_boxes"]),
                dim_room=config["dim_room"],
                room_fixed=fixed_grid,
                room_state=state_grid,
                box_mapping=mapping_json,
                solution_path=[int(step) for step in solution],
                solution_length=int(length),
                generation_seed=int(attempt_seed),
                max_steps=int(config["max_steps"]),
            )

            logger.info(
                f"Generated {difficulty} puzzle {puzzle_id} (seed: {attempt_seed}, solution length: {length})"
            )
            return puzzle

        except Exception as err:
            logger.warning(f"Error generating puzzle {puzzle_id} attempt {attempt_no}: {err}")

    logger.error(f"Failed to generate puzzle {puzzle_id} after {max_attempts} attempts")
    return None
|
239
|
+
|
240
|
+
|
241
|
+
def _stable_base_seed(difficulty: str) -> int:
    """Derive a process-independent base seed in ``[0, 100000)`` from a tier name.

    The built-in ``hash()`` of a ``str`` is salted per interpreter process (see
    PYTHONHASHSEED), so it would give different seeds on every run.
    """
    import zlib

    return zlib.crc32(difficulty.encode("utf-8")) % 100000


def generate_all_puzzles(num_per_difficulty: int = 100) -> Dict[str, List[SokobanPuzzle]]:
    """Generate puzzles for every difficulty tier, appending each to its JSONL file.

    Generation resumes from earlier runs. Puzzle IDs already present in a tier's
    JSONL file are skipped, and new puzzles are produced only until the file's
    non-blank line count plus this run's output reaches ``num_per_difficulty``.
    At most ``num_per_difficulty * 5`` puzzle indices are tried per tier.

    Args:
        num_per_difficulty: Target number of puzzles per difficulty tier.

    Returns:
        Mapping from tier name to the puzzles generated during *this* run.
        Puzzles loaded from earlier runs are not included.
    """
    all_puzzles: Dict[str, List[SokobanPuzzle]] = {}
    total_puzzles = 0

    instances_dir = setup_instances_directory()
    logger.info(f"Using instances directory: {instances_dir}")

    for difficulty, config in DIFFICULTY_CONFIGS.items():
        jsonl_path = get_jsonl_path(instances_dir, difficulty)
        existing_ids = load_existing_puzzles(jsonl_path)
        existing_count = count_existing_puzzles(jsonl_path)

        logger.info(f"Processing {difficulty} difficulty...")
        logger.info(f" Found {existing_count} existing puzzles")
        logger.info(f" Target: {num_per_difficulty} puzzles")

        puzzles: List[SokobanPuzzle] = []
        # Stable across processes (unlike hash()), so reruns reproduce the same seeds.
        base_seed = _stable_base_seed(difficulty)
        generated_this_run = 0

        # Safety limit: try at most 5x the target number of puzzle indices.
        for i in range(num_per_difficulty * 5):
            if len(puzzles) + existing_count >= num_per_difficulty:
                break

            puzzle_id = f"{difficulty}_{i:03d}"
            if puzzle_id in existing_ids:
                continue

            puzzle = generate_puzzle_for_difficulty(
                difficulty=difficulty, config=config, seed=base_seed + i, puzzle_id=puzzle_id
            )
            if puzzle is None:
                logger.warning(f"Failed to generate puzzle {puzzle_id}")
                continue

            puzzles.append(puzzle)
            # Persist immediately so an interrupted run keeps its progress.
            save_puzzle_to_jsonl(puzzle, jsonl_path)
            generated_this_run += 1
            total_puzzles += 1
            logger.info(
                f"Generated and saved {difficulty} puzzle {puzzle_id} ({generated_this_run}/{num_per_difficulty - existing_count} new)"
            )

        all_puzzles[difficulty] = puzzles
        logger.info(
            f"Completed {difficulty}: {generated_this_run} new puzzles generated, {existing_count + len(puzzles)} total"
        )

    logger.info(f"Total new puzzles generated this run: {total_puzzles}")
    return all_puzzles
|
308
|
+
|
309
|
+
|
310
|
+
def load_all_puzzles_from_jsonl(instances_dir: Path) -> Dict[str, List[SokobanPuzzle]]:
    """Load every tier's puzzles from its JSONL file under ``instances_dir``.

    Blank lines are skipped silently. Records that are not valid JSON objects, or
    that lack a required field, are logged as warnings and skipped. A tier whose
    file does not exist maps to an empty list.
    """
    all_puzzles: Dict[str, List[SokobanPuzzle]] = {}

    for difficulty in DIFFICULTY_CONFIGS:
        jsonl_path = get_jsonl_path(instances_dir, difficulty)
        puzzles: List[SokobanPuzzle] = []

        if jsonl_path.exists():
            with open(jsonl_path, "r") as f:
                for line in f:
                    stripped = line.strip()
                    if not stripped:
                        # Blank lines (e.g. a trailing newline) are not records.
                        continue
                    try:
                        puzzle_data = json.loads(stripped)
                        puzzle = SokobanPuzzle(
                            id=puzzle_data["id"],
                            difficulty=puzzle_data["difficulty"],
                            num_boxes=puzzle_data["num_boxes"],
                            # JSON stores tuples as lists; restore the declared tuple type.
                            dim_room=tuple(puzzle_data["dim_room"]),
                            room_fixed=puzzle_data["room_fixed"],
                            room_state=puzzle_data["room_state"],
                            box_mapping=puzzle_data["box_mapping"],
                            solution_path=puzzle_data["solution_path"],
                            solution_length=puzzle_data["solution_length"],
                            generation_seed=puzzle_data["generation_seed"],
                            max_steps=puzzle_data["max_steps"],
                        )
                    except (json.JSONDecodeError, KeyError, TypeError) as e:
                        # TypeError covers non-object records (e.g. a JSON list).
                        logger.warning(f"Error loading puzzle from {jsonl_path}: {e}")
                        continue
                    puzzles.append(puzzle)

        all_puzzles[difficulty] = puzzles

    return all_puzzles
|
344
|
+
|
345
|
+
|
346
|
+
def save_puzzles_to_json(puzzles: Dict[str, List[SokobanPuzzle]], output_path: Path):
    """
    Save puzzles to JSON file.

    The written document has a ``metadata`` header (format version, total
    puzzle count, the configured difficulty names, and the UTC time the file
    was written) followed by a ``puzzles`` mapping. In that mapping each
    puzzle is converted with ``dataclasses.asdict``; values ``json`` cannot
    encode natively are passed to ``convert_numpy_types``.

    Args:
        puzzles: Dictionary of puzzles by difficulty
        output_path: Path to save the JSON file
    """
    from datetime import datetime, timezone

    # Convert to serializable format
    serializable_puzzles = {
        difficulty: [asdict(puzzle) for puzzle in puzzle_list]
        for difficulty, puzzle_list in puzzles.items()
    }

    # Add metadata
    output_data = {
        "metadata": {
            "version": "1.0",
            "total_puzzles": sum(len(puzzle_list) for puzzle_list in serializable_puzzles.values()),
            "difficulties": list(DIFFICULTY_CONFIGS.keys()),
            # Actual write time (UTC), in the same format the old fixed placeholder used.
            "generated_at": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        },
        "puzzles": serializable_puzzles,
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2, default=convert_numpy_types)

    logger.info(f"Saved puzzles to {output_path}")
|
374
|
+
|
375
|
+
|
376
|
+
def create_unified_json_from_jsonl():
    """Merge every per-difficulty JSONL file into ``verified_puzzles.json``.

    The unified file, intended for the puzzle loader, is written to the
    directory containing this module.

    Returns:
        The difficulty -> puzzle-list mapping that was loaded and written.
    """
    puzzles_by_difficulty = load_all_puzzles_from_jsonl(setup_instances_directory())

    unified_path = Path(__file__).parent / "verified_puzzles.json"
    save_puzzles_to_json(puzzles_by_difficulty, unified_path)

    return puzzles_by_difficulty
|
386
|
+
|
387
|
+
|
388
|
+
def main(num_per_difficulty: int = 100):
    """Generate and save all puzzles, report counts, and rebuild the unified JSON.

    Steps:
      1. ``generate_all_puzzles`` (which saves incrementally to JSONL).
      2. Log a per-difficulty summary of the puzzles returned by this run.
      3. Log the total puzzle count found in each difficulty's JSONL file.
      4. Rebuild the unified JSON file via ``create_unified_json_from_jsonl``.

    Args:
        num_per_difficulty: Forwarded unchanged to ``generate_all_puzzles``.
            Defaults to 100, the value previously hard-coded here.
    """
    logger.info("Starting Sokoban puzzle generation with incremental saving...")

    # Generate puzzles (saves incrementally to JSONL)
    puzzles = generate_all_puzzles(num_per_difficulty=num_per_difficulty)

    # Print summary of this run
    logger.info("Puzzle generation complete!")
    logger.info("Summary of this run:")
    for difficulty, puzzle_list in puzzles.items():
        if puzzle_list:
            avg_solution_length = sum(p.solution_length for p in puzzle_list) / len(puzzle_list)
            logger.info(
                f" {difficulty}: {len(puzzle_list)} new puzzles, avg solution length: {avg_solution_length:.1f}"
            )
        else:
            logger.info(f" {difficulty}: 0 new puzzles")

    # Show total counts from JSONL files
    instances_dir = setup_instances_directory()
    logger.info("Total puzzles saved:")
    for difficulty in DIFFICULTY_CONFIGS.keys():
        jsonl_path = get_jsonl_path(instances_dir, difficulty)
        total_count = count_existing_puzzles(jsonl_path)
        logger.info(f" {difficulty}: {total_count} total puzzles")

    # Create unified JSON file for the puzzle loader
    logger.info("Creating unified JSON file for puzzle loader...")
    create_unified_json_from_jsonl()
    logger.info("Unified JSON file created successfully!")
|
419
|
+
|
420
|
+
|
421
|
+
if __name__ == "__main__":
    import sys

    # `--create-json` as the first CLI argument skips generation and only
    # rebuilds the unified JSON from the JSONL files already on disk.
    if sys.argv[1:2] != ["--create-json"]:
        main()
    else:
        logger.info("Creating unified JSON file from existing JSONL files...")
        loaded = create_unified_json_from_jsonl()
        logger.info("Summary of loaded puzzles:")
        for difficulty, puzzle_list in loaded.items():
            if not puzzle_list:
                logger.info(f" {difficulty}: 0 puzzles")
                continue
            avg_solution_length = sum(p.solution_length for p in puzzle_list) / len(puzzle_list)
            logger.info(
                f" {difficulty}: {len(puzzle_list)} puzzles, avg solution length: {avg_solution_length:.1f}"
            )
|