synth-ai 0.2.4.dev3__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +575 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/tracing_v3/examples/basic_usage.py +188 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +105 -6
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,425 @@
|
|
1
|
+
import logging
import os
from dataclasses import asdict, dataclass, fields
from typing import List, Optional, Tuple
from uuid import UUID, uuid4

from synth_ai.environments.examples.sokoban.puzzle_loader import (
    SokobanPuzzle,
    get_puzzle_loader,
)
from synth_ai.environments.tasks.core import (
    Task,
    TaskInstance,
    TaskInstanceMetadata,
    TaskInstanceMetadataFilter,
    TaskInstanceSet,
)
from synth_ai.environments.tasks.core import Impetus, Intent, SplitInfo
|
18
|
+
|
19
|
+
logger = logging.getLogger(__name__)

# Task-level description shared by every Sokoban task instance built in this module.
sokoban_task = Task(
    global_premises="Procedural Sokoban task generation",
    global_objectives="Push all boxes onto target locations",
    global_constraints="",
    shared_env_params={},
)
|
27
|
+
|
28
|
+
# Configuration parameters

# How many puzzles create_sokoban_taskset takes from each difficulty.
NUM_INSTANCES_PER_DIFFICULTY = 10

# Per-difficulty settings. Each entry currently holds only the instruction
# text that becomes the task instance's Impetus.
DIFFICULTY_CONFIGS = dict(
    ultra_easy=dict(
        impetus_prompt="Solve this very simple Sokoban puzzle by pushing the box onto the target.",
    ),
    easy=dict(
        impetus_prompt="Solve this simple Sokoban puzzle by pushing the box onto the target.",
    ),
    medium=dict(
        impetus_prompt="Solve this Sokoban puzzle by pushing the 2 boxes onto the targets.",
    ),
    hard=dict(
        impetus_prompt="Solve this challenging Sokoban puzzle by pushing the 3 boxes onto the targets.",
    ),
)
|
44
|
+
|
45
|
+
|
46
|
+
@dataclass
class SokobanTaskInstanceMetadata(TaskInstanceMetadata):
    """Sokoban-specific metadata attached to every task instance.

    The declaration order below is the positional order of the generated
    ``__init__`` and must not be changed.
    """

    # Difficulty label the instance was built for (a DIFFICULTY_CONFIGS key).
    difficulty: str
    # Number of boxes, copied from the source puzzle.
    num_boxes: int
    # Room dimensions as a pair of ints, copied from the source puzzle.
    dim_room: Tuple[int, int]
    # Step budget, copied from the source puzzle's max_steps.
    max_steps: int
    # Copied from the source puzzle's solution_length.
    shortest_path_length: int
    # The puzzle's generation seed, or the selection seed when the instance
    # is built by create_task_instance_from_seed.
    seed: int
    # Free-form provenance string, e.g. "verified_puzzle_id=<id>".
    generation_params: str
|
55
|
+
|
56
|
+
|
57
|
+
@dataclass
class SokobanTaskInstance(TaskInstance):
    """Sokoban task instance with JSON-friendly (de)serialization."""

    async def serialize(self) -> dict:
        """Return a plain-dict representation of this instance.

        ``dataclasses.asdict`` recursively converts dataclass-valued fields to
        dicts. A UUID ``id`` is converted to ``str``. Any
        ``intent.deterministic_eval_functions`` are replaced by an empty list,
        because callables cannot be written out.
        """
        data = asdict(self)
        if isinstance(data.get("id"), UUID):
            data["id"] = str(data["id"])
        intent = data.get("intent")
        if intent is not None and "deterministic_eval_functions" in intent:
            intent["deterministic_eval_functions"] = []
        return data

    @classmethod
    async def deserialize(cls, data: dict) -> "SokobanTaskInstance":
        """Rebuild an instance from the output of ``serialize``.

        Non-UUID ids (e.g. ``'demo-mcts'``) are accepted and kept as-is. Keys
        that are not dataclass fields of ``cls`` are ignored. The input dict,
        and the dicts nested in it, are not mutated.
        """
        # Shallow copy so the caller's dict is left untouched.
        data = dict(data)

        if "id" in data:
            try:
                data["id"] = UUID(str(data["id"]))
            except (ValueError, TypeError, AttributeError):
                pass  # keep the original, non-UUID id

        if isinstance(data.get("impetus"), dict):
            data["impetus"] = Impetus(**data["impetus"])

        if isinstance(data.get("intent"), dict):
            intent_data = dict(data["intent"])
            # Eval callables are never serialized; always restore an empty list.
            intent_data["deterministic_eval_functions"] = []
            data["intent"] = Intent(**intent_data)

        if isinstance(data.get("metadata"), dict):
            metadata_data = dict(data["metadata"])
            # A JSON round-trip turns tuples into lists; restore the declared
            # Tuple[int, int] so equality-based filters (e.g. on dim_room) work.
            if isinstance(metadata_data.get("dim_room"), list):
                metadata_data["dim_room"] = tuple(metadata_data["dim_room"])
            data["metadata"] = SokobanTaskInstanceMetadata(**metadata_data)

        constructor_field_names = {f.name for f in fields(cls)}
        filtered_data = {k: v for k, v in data.items() if k in constructor_field_names}

        return cls(**filtered_data)
|
94
|
+
|
95
|
+
|
96
|
+
async def create_sokoban_taskset(
    num_instances_per_difficulty: Optional[int] = None,
) -> TaskInstanceSet:
    """Generate Sokoban task instances from pre-generated verified puzzles.

    Args:
        num_instances_per_difficulty: Maximum number of puzzles taken from each
            difficulty in ``DIFFICULTY_CONFIGS``. Defaults to the module-level
            ``NUM_INSTANCES_PER_DIFFICULTY``. The default is read at call time,
            so rebinding that constant still takes effect. Negative values are
            treated as 0.

    Returns:
        A ``TaskInstanceSet`` with three splits:
        - validation: instances with exactly 2 boxes;
        - test: remaining instances whose shortest path length is at most 10;
        - training: everything else, implicitly.
        If the puzzles cannot be loaded, an empty taskset is returned.
    """
    if num_instances_per_difficulty is None:
        num_instances_per_difficulty = NUM_INSTANCES_PER_DIFFICULTY
    # Clamp so a negative value cannot turn the slice below into "all but the last N".
    per_difficulty = max(0, num_instances_per_difficulty)

    instances = []

    # Load pre-generated puzzles
    try:
        puzzle_loader = get_puzzle_loader()
        logger.info("Loading pre-generated Sokoban puzzles...")
        puzzle_loader.load_puzzles()
        logger.info(f"Loaded {puzzle_loader.get_total_puzzle_count()} total puzzles")
    except Exception as e:
        # Deliberate best-effort: a missing puzzle file yields an empty taskset.
        logger.error(f"Failed to load pre-generated puzzles: {e}")
        logger.warning("Falling back to empty taskset. Run generate_verified_puzzles.py first.")
        return TaskInstanceSet(
            name="Sokoban Verified TaskSet",
            description="Verified pre-generated Sokoban tasks with guaranteed solvability.",
            instances=[],
            split_info=SplitInfo(
                val_instance_ids=set(), test_instance_ids=set(), _is_split_defined=True
            ),
        )

    for difficulty, config in DIFFICULTY_CONFIGS.items():
        available_puzzles = puzzle_loader.get_puzzles_by_difficulty(difficulty)

        if not available_puzzles:
            logger.warning(f"No puzzles found for difficulty {difficulty}")
            continue

        # Take up to per_difficulty puzzles
        puzzles_to_use = available_puzzles[:per_difficulty]
        logger.info(f"Using {len(puzzles_to_use)} puzzles for {difficulty} difficulty")

        for puzzle in puzzles_to_use:
            impetus = Impetus(instructions=config["impetus_prompt"])
            intent = Intent(
                rubric={"goal": "Push all boxes onto target locations."},
                gold_trajectories=None,
                gold_state_diff={},
            )
            metadata = SokobanTaskInstanceMetadata(
                difficulty=difficulty,
                num_boxes=puzzle.num_boxes,
                dim_room=puzzle.dim_room,
                max_steps=puzzle.max_steps,
                shortest_path_length=puzzle.solution_length,
                seed=puzzle.generation_seed,
                generation_params=f"verified_puzzle_id={puzzle.id}",
            )

            instances.append(
                SokobanTaskInstance(
                    id=uuid4(),
                    impetus=impetus,
                    intent=intent,
                    metadata=metadata,
                    is_reproducible=True,
                    # Use the puzzle data as the initial engine snapshot
                    initial_engine_snapshot=puzzle.to_engine_snapshot(),
                )
            )

    class NumBoxesFilter(TaskInstanceMetadataFilter):
        """Match instances whose metadata has exactly ``num_boxes`` boxes."""

        def __init__(self, num_boxes):
            self.num_boxes = num_boxes

        def __call__(self, instance):
            if hasattr(instance.metadata, "num_boxes"):
                return instance.metadata.num_boxes == self.num_boxes
            return False

    class DimRoomFilter(TaskInstanceMetadataFilter):
        """Match instances whose metadata ``dim_room`` equals ``dim_room``."""

        def __init__(self, dim_room):
            self.dim_room = dim_room

        def __call__(self, instance):
            if hasattr(instance.metadata, "dim_room"):
                return instance.metadata.dim_room == self.dim_room
            return False

    class PathLengthFilter(TaskInstanceMetadataFilter):
        """Match instances whose shortest path length lies within the inclusive bounds."""

        def __init__(self, min_length=None, max_length=None):
            self.min_length = min_length
            self.max_length = max_length

        def __call__(self, instance):
            if not hasattr(instance.metadata, "shortest_path_length"):
                return False
            length = instance.metadata.shortest_path_length
            if self.min_length is not None and length < self.min_length:
                return False
            if self.max_length is not None and length > self.max_length:
                return False
            return True

    val_filter = NumBoxesFilter(2)
    test_filter = PathLengthFilter(max_length=10)
    val_ids = {inst.id for inst in instances if val_filter(inst)}
    # Remove anything already tagged as validation so the splits are disjoint.
    test_ids = {inst.id for inst in instances if test_filter(inst) and inst.id not in val_ids}
    split_info = SplitInfo(
        val_instance_ids=val_ids,
        test_instance_ids=test_ids,
        _is_split_defined=True,
    )

    return TaskInstanceSet(
        name="Sokoban Verified TaskSet",
        description="Verified pre-generated Sokoban tasks with guaranteed solvability.",
        instances=instances,
        split_info=split_info,
    )
|
211
|
+
|
212
|
+
|
213
|
+
async def create_easy_sokoban_taskset(num_instances: int = 50) -> TaskInstanceSet:
    """Build a taskset restricted to the ``"easy"`` difficulty.

    Thin wrapper that delegates to ``create_filtered_sokoban_taskset``, taking
    up to ``num_instances`` easy puzzles.
    """
    easy_only = ["easy"]
    taskset = await create_filtered_sokoban_taskset(
        easy_only, num_instances_per_difficulty=num_instances
    )
    return taskset
|
218
|
+
|
219
|
+
|
220
|
+
async def create_filtered_sokoban_taskset(
    difficulties: List[str], num_instances_per_difficulty: int = 10
) -> TaskInstanceSet:
    """
    Create a taskset with only specified difficulties.

    Args:
        difficulties: Difficulty levels to include. Names that are not keys of
            DIFFICULTY_CONFIGS are skipped with a warning. A repeated name is
            processed only once.
        num_instances_per_difficulty: Maximum number of instances per
            difficulty. Negative values are treated as 0.

    Returns:
        TaskInstanceSet with only the specified difficulties. Instances at
        positions 0, 3, 6, ... form the validation split and positions
        1, 4, 7, ... form the test split. The rest are implicitly training.
        If the puzzles cannot be loaded, an empty taskset is returned.
    """
    # Materialize once: the names feed both the description and the loop.
    requested = list(difficulties)
    description = f"Filtered Sokoban tasks for difficulties: {', '.join(requested)}"
    # De-duplicate (order-preserving) so a repeated name does not add the same
    # puzzles a second time under fresh ids.
    unique_difficulties = list(dict.fromkeys(requested))
    # Clamp so a negative value cannot turn the slice into "all but the last N".
    per_difficulty = max(0, num_instances_per_difficulty)

    instances = []

    # Load pre-generated puzzles
    try:
        puzzle_loader = get_puzzle_loader()
        logger.info("Loading pre-generated Sokoban puzzles...")
        puzzle_loader.load_puzzles()
        logger.info(f"Loaded {puzzle_loader.get_total_puzzle_count()} total puzzles")
    except Exception as e:
        # Deliberate best-effort: a missing puzzle file yields an empty taskset.
        logger.error(f"Failed to load pre-generated puzzles: {e}")
        return TaskInstanceSet(
            name="Sokoban Filtered TaskSet",
            description=description,
            instances=[],
            split_info=SplitInfo(
                val_instance_ids=set(), test_instance_ids=set(), _is_split_defined=True
            ),
        )

    for difficulty in unique_difficulties:
        config = DIFFICULTY_CONFIGS.get(difficulty)
        if config is None:
            logger.warning(f"Unknown difficulty '{difficulty}', skipping")
            continue

        available_puzzles = puzzle_loader.get_puzzles_by_difficulty(difficulty)
        if not available_puzzles:
            logger.warning(f"No puzzles found for difficulty {difficulty}")
            continue

        # Take up to per_difficulty puzzles
        puzzles_to_use = available_puzzles[:per_difficulty]
        logger.info(f"Using {len(puzzles_to_use)} puzzles for {difficulty} difficulty")

        for puzzle in puzzles_to_use:
            impetus = Impetus(instructions=config["impetus_prompt"])
            intent = Intent(
                rubric={"goal": "Push all boxes onto target locations."},
                gold_trajectories=None,
                gold_state_diff={},
            )
            metadata = SokobanTaskInstanceMetadata(
                difficulty=difficulty,
                num_boxes=puzzle.num_boxes,
                dim_room=puzzle.dim_room,
                max_steps=puzzle.max_steps,
                shortest_path_length=puzzle.solution_length,
                seed=puzzle.generation_seed,
                generation_params=f"verified_puzzle_id={puzzle.id}",
            )

            instances.append(
                SokobanTaskInstance(
                    id=uuid4(),
                    impetus=impetus,
                    intent=intent,
                    metadata=metadata,
                    is_reproducible=True,
                    # Use the puzzle data as the initial engine snapshot
                    initial_engine_snapshot=puzzle.to_engine_snapshot(),
                )
            )

    # Simple positional split: every 3rd instance for validation, every 3rd
    # starting at index 1 for test; the remainder is implicitly training.
    val_ids = {inst.id for inst in instances[::3]}
    test_ids = {inst.id for inst in instances[1::3]}
    split_info = SplitInfo(
        val_instance_ids=val_ids,
        test_instance_ids=test_ids,
        _is_split_defined=True,
    )

    return TaskInstanceSet(
        name="Sokoban Filtered TaskSet",
        description=description,
        instances=instances,
        split_info=split_info,
    )
|
315
|
+
|
316
|
+
|
317
|
+
async def create_task_instance_from_seed(difficulty: str, seed: int) -> SokobanTaskInstance:
    """
    Create a single task instance from a specific seed.

    The puzzle is chosen by ``get_puzzle_by_seed(difficulty, seed)`` from the
    puzzle loader module.

    Args:
        difficulty: The difficulty level; must be a key of DIFFICULTY_CONFIGS.
        seed: Seed for deterministic puzzle selection. It is stored as the
            instance's metadata seed.

    Returns:
        Single SokobanTaskInstance (with a fresh random id).

    Raises:
        ValueError: If ``difficulty`` is unknown, or if no puzzle is available
            for it.
    """
    from synth_ai.environments.examples.sokoban.puzzle_loader import get_puzzle_by_seed

    # Validate the difficulty before the lookup, so an unknown name is
    # reported as such rather than as "no puzzles available".
    config = DIFFICULTY_CONFIGS.get(difficulty)
    if not config:
        raise ValueError(f"Unknown difficulty '{difficulty}'")

    puzzle = get_puzzle_by_seed(difficulty, seed)
    if not puzzle:
        raise ValueError(f"No puzzles available for difficulty '{difficulty}'")

    impetus = Impetus(instructions=config["impetus_prompt"])
    intent = Intent(
        rubric={"goal": "Push all boxes onto target locations."},
        gold_trajectories=None,
        gold_state_diff={},
    )
    metadata = SokobanTaskInstanceMetadata(
        difficulty=difficulty,
        num_boxes=puzzle.num_boxes,
        dim_room=puzzle.dim_room,
        max_steps=puzzle.max_steps,
        shortest_path_length=puzzle.solution_length,
        seed=seed,  # Use the input seed, not the puzzle's generation seed
        generation_params=f"verified_puzzle_id={puzzle.id}_from_seed={seed}",
    )

    return SokobanTaskInstance(
        id=uuid4(),
        impetus=impetus,
        intent=intent,
        metadata=metadata,
        is_reproducible=True,
        # Use the puzzle data as the initial engine snapshot
        initial_engine_snapshot=puzzle.to_engine_snapshot(),
    )
|
370
|
+
|
371
|
+
|
372
|
+
# Example usage
if __name__ == "__main__":
    import asyncio
    import json
    import os

    # Rebinds the module-level constant; create_sokoban_taskset reads it at
    # call time, so this caps the demo at 2 puzzles per difficulty.
    NUM_INSTANCES_PER_DIFFICULTY = 2
    # Anchor the output next to this module (examples/sokoban/dataset/instances.json)
    # instead of depending on the current working directory.
    OUTPUT_FILE_PATH = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "dataset", "instances.json"
    )

    async def main():
        """Build the taskset, round-trip it through JSON, and print the splits."""
        taskset = await create_sokoban_taskset()

        serialized = await asyncio.gather(*(inst.serialize() for inst in taskset.instances))

        os.makedirs(os.path.dirname(OUTPUT_FILE_PATH), exist_ok=True)

        with open(OUTPUT_FILE_PATH, "w", encoding="utf-8") as f:
            json.dump(serialized, f, indent=2)
        print(f"Serialized {len(serialized)} instances to {OUTPUT_FILE_PATH}")

        with open(OUTPUT_FILE_PATH, "r", encoding="utf-8") as f:
            read_serialized_data = json.load(f)

        deserialized = await asyncio.gather(
            *(SokobanTaskInstance.deserialize(data) for data in read_serialized_data)
        )
        print(f"Deserialized {len(deserialized)} instances.")

        # Defensive check: report any instance that failed to deserialize.
        if any(inst is None for inst in deserialized):
            print("Error: Deserialization returned None for some instances.")
            for i, inst in enumerate(deserialized):
                if inst is None:
                    print(
                        f"Instance at index {i} is None. Serialized data: {read_serialized_data[i]}"
                    )
            return

        val_ids = taskset.split_info.val_instance_ids
        test_ids = taskset.split_info.test_instance_ids
        all_ids = {inst.id for inst in deserialized}
        train_ids = all_ids - val_ids - test_ids

        train = [inst for inst in deserialized if inst.id in train_ids]
        val = [inst for inst in deserialized if inst.id in val_ids]
        test = [inst for inst in deserialized if inst.id in test_ids]

        print(f"Train set ({len(train)} instances): {[str(i.id) for i in train]}")
        print(f"Val set ({len(val)} instances): {[str(i.id) for i in val]}")
        print(f"Test set ({len(test)} instances): {[str(i.id) for i in test]}")

    asyncio.run(main())
|
@@ -0,0 +1 @@
|
|
1
|
+
# TicTacToe Environment Module
|