synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +579 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/wordle/__init__.py +29 -0
- synth_ai/environments/examples/wordle/engine.py +391 -0
- synth_ai/environments/examples/wordle/environment.py +154 -0
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
- synth_ai/environments/examples/wordle/taskset.py +222 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/environments/service/core_routes.py +38 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
- synth_ai/learning/prompts/mipro.py +273 -1
- synth_ai/learning/prompts/random_search.py +247 -0
- synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
- synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
- synth_ai/lm/injection.py +81 -0
- synth_ai/lm/overrides.py +204 -0
- synth_ai/lm/provider_support/anthropic.py +39 -12
- synth_ai/lm/provider_support/openai.py +31 -4
- synth_ai/lm/vendors/core/anthropic_api.py +16 -0
- synth_ai/lm/vendors/openai_standard.py +35 -5
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,391 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple
|
5
|
+
from collections import Counter
|
6
|
+
import random
|
7
|
+
import string
|
8
|
+
|
9
|
+
from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
|
10
|
+
from synth_ai.environments.reproducibility.core import IReproducibleEngine
|
11
|
+
from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent
|
12
|
+
from synth_ai.environments.environment.shared_engine import (
|
13
|
+
GetObservationCallable,
|
14
|
+
InternalObservation,
|
15
|
+
)
|
16
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
17
|
+
|
18
|
+
|
19
|
+
# Built-in pool of answer words. Every entry is a five-letter lowercase word;
# the order is significant because seeded target selection indexes into it.
DEFAULT_SOLUTIONS = [
    "cigar", "rebut", "sissy", "humph", "awake",
    "blush", "focal", "evade", "naval", "serve",
    "heath", "dwarf", "model", "karma", "stink",
    "grade", "quiet", "bench", "abate", "feign",
    "major", "death", "fresh", "crust", "stool",
    "colon", "abase", "marry", "react", "batty",
    "pride", "floss", "helix", "croak", "staff",
    "paper", "unfed", "whelp", "trawl", "outdo",
    "adobe", "crazy", "sower", "repay", "digit",
    "crate", "cluck", "spike", "mimic", "pound",
]
|
71
|
+
|
72
|
+
|
73
|
+
def _sanitize(word: str) -> str:
|
74
|
+
w = word.strip().lower()
|
75
|
+
if not w or not all(c in string.ascii_lowercase for c in w):
|
76
|
+
raise ValueError("word must contain only a–z letters")
|
77
|
+
return w
|
78
|
+
|
79
|
+
|
80
|
+
def _score_guess(guess: str, target: str) -> str:
|
81
|
+
res = ["B"] * len(target)
|
82
|
+
counts = Counter(target)
|
83
|
+
for i, ch in enumerate(guess):
|
84
|
+
if ch == target[i]:
|
85
|
+
res[i] = "G"
|
86
|
+
counts[ch] -= 1
|
87
|
+
for i, ch in enumerate(guess):
|
88
|
+
if res[i] == "G":
|
89
|
+
continue
|
90
|
+
if counts.get(ch, 0) > 0:
|
91
|
+
res[i] = "Y"
|
92
|
+
counts[ch] -= 1
|
93
|
+
return "".join(res)
|
94
|
+
|
95
|
+
|
96
|
+
@dataclass
class WordlePublicState:
    """Game state that is safe to show to the player."""

    word_length: int
    remaining_guesses: int
    max_guesses: int
    guesses: List[str]
    feedback: List[str]  # Parallel to guesses; strings of 'G/Y/B'
    last_feedback: Optional[str]
    last_guess: Optional[str]
    terminated: bool
    status: str  # "in_progress" | "won" | "lost"

    @property
    def board_text(self) -> str:
        """Render the board, one ``GUESS | G Y B ...`` row per guess made."""
        if not self.guesses:
            return "(no guesses yet)"
        return "\n".join(
            f"{word.upper()} | {' '.join(marks)}"
            for word, marks in zip(self.guesses, self.feedback)
        )
|
117
|
+
|
118
|
+
|
119
|
+
@dataclass
class WordlePrivateState:
    """Reward and termination bookkeeping that accompanies a public state."""

    reward_last: float  # reward produced by the most recent step
    total_reward: float  # running sum of rewards for the episode
    terminated: bool
    truncated: bool
|
125
|
+
|
126
|
+
|
127
|
+
@dataclass
class WordleEngineSnapshot(StatefulEngineSnapshot):
    """Serializable capture of a Wordle engine together with its task instance."""

    task_instance_dict: Dict  # serialized form of the task instance
    engine_snapshot: Dict  # engine configuration and runtime fields
|
131
|
+
|
132
|
+
|
133
|
+
class WordleWinComponent(RewardComponent):
    """Reward component worth 1.0 when the scored state is a win, else 0.0."""

    async def score(self, state: WordlePublicState, action: Any) -> float:
        if state.status == "won":
            return 1.0
        return 0.0
|
136
|
+
|
137
|
+
|
138
|
+
class WordleInvalidGuessComponent(RewardComponent):
    """One-shot penalty for an invalid guess.

    Whoever detects the invalid guess sets ``invalid_attempted``; the next call
    to ``score`` charges -1.0 and clears the flag again.
    """

    def __init__(self) -> None:
        self.invalid_attempted = False

    async def score(self, state: WordlePublicState, action: Any) -> float:
        penalty = -1.0 if self.invalid_attempted else 0.0
        # The penalty applies once per flagged attempt, so always leave the flag cleared.
        self.invalid_attempted = False
        return penalty
|
147
|
+
|
148
|
+
|
149
|
+
class WordleEngine(StatefulEngine, IReproducibleEngine):
    """Game logic for one Wordle episode.

    Configuration is read from ``task_instance.metadata`` when present:
    ``word_length`` (default 5), ``max_guesses`` (default 6),
    ``enforce_wordlist`` (default False), ``consume_invalid_attempts``
    (default True), ``target_word`` and ``seed``.

    The target is the metadata ``target_word`` if one is given; otherwise it is
    drawn from ``base_word_list`` on reset, reproducibly when a seed is known.
    """

    def __init__(self, task_instance: TaskInstance):
        self.task_instance = task_instance

        # Read config from metadata, falling back to defaults when it is absent
        md = getattr(task_instance, "metadata", None)
        self.word_length: int = getattr(md, "word_length", 5) if md else 5
        self.max_guesses: int = getattr(md, "max_guesses", 6) if md else 6
        self.enforce_wordlist: bool = getattr(md, "enforce_wordlist", False) if md else False
        # Toggle: whether invalid actions consume a turn (default True)
        self.consume_invalid_attempts: bool = (
            getattr(md, "consume_invalid_attempts", True) if md else True
        )

        # Words of the configured length; if none match, fall back to the 5-letter words
        self.base_word_list: List[str] = [
            w for w in DEFAULT_SOLUTIONS if len(w) == self.word_length
        ] or [w for w in DEFAULT_SOLUTIONS if len(w) == 5]

        # Target selection: prefer explicit target_word in metadata; else pick deterministically by seed
        raw_target = getattr(md, "target_word", None) if md else None
        self.fixed_target: Optional[str] = _sanitize(raw_target) if raw_target else None
        self.seed: Optional[int] = getattr(md, "seed", None) if md else None

        # Runtime state
        self.target: Optional[str] = None
        self.guesses: List[str] = []
        self.feedback: List[str] = []
        self.remaining_guesses: int = self.max_guesses
        self.status: str = "in_progress"
        self.terminated: bool = False
        self.total_reward: float = 0.0

        # Rewards
        self.invalid_component = WordleInvalidGuessComponent()
        self.reward_stack = RewardStack([WordleWinComponent(), self.invalid_component])

    def _public_state(self) -> WordlePublicState:
        """Build a public state from the current fields (lists are copied)."""
        return WordlePublicState(
            word_length=self.word_length,
            remaining_guesses=self.remaining_guesses,
            max_guesses=self.max_guesses,
            guesses=self.guesses.copy(),
            feedback=self.feedback.copy(),
            last_feedback=self.feedback[-1] if self.feedback else None,
            last_guess=self.guesses[-1] if self.guesses else None,
            terminated=self.terminated,
            status=self.status,
        )

    def _private_state(self, reward_last: float) -> WordlePrivateState:
        """Build a private state carrying *reward_last* and the running total."""
        return WordlePrivateState(
            reward_last=reward_last,
            total_reward=self.total_reward,
            terminated=self.terminated,
            truncated=False,
        )

    async def _reset_engine(
        self, *, seed: int | None = None
    ) -> Tuple[WordlePrivateState, WordlePublicState]:
        """Start a fresh episode and return its initial ``(private, public)`` states.

        Args:
            seed: Optional seed for target selection; defaults to the metadata
                seed. Ignored when a fixed target word is configured.
        """
        if seed is None:
            seed = self.seed

        if self.fixed_target:
            self.target = self.fixed_target
        elif seed is not None:
            # A private generator yields the same pick as seeding the module-level
            # RNG would, without disturbing random state shared with other code.
            self.target = random.Random(seed).choice(self.base_word_list)
        else:
            self.target = random.choice(self.base_word_list)

        self.guesses = []
        self.feedback = []
        self.remaining_guesses = self.max_guesses
        self.status = "in_progress"
        self.terminated = False
        self.total_reward = 0.0

        return self._private_state(0.0), self._public_state()

    async def _step_engine(self, action: str) -> Tuple[WordlePrivateState, WordlePublicState]:
        """Apply one guess and return the resulting ``(private, public)`` states.

        A guess is invalid when its length differs from ``word_length`` or,
        with ``enforce_wordlist``, when it is neither in ``base_word_list`` nor
        the target itself. Invalid guesses flag the penalty component and use
        up a turn only if ``consume_invalid_attempts`` is set.

        Once the episode has terminated, further calls change nothing and
        return the final state with a zero step reward.

        Raises:
            ValueError: if *action* is empty or contains non-letter characters.
        """
        assert self.target is not None

        # A finished game is frozen: otherwise the guess counter could go
        # negative, the status could flip back, and a win could be re-scored.
        if self.terminated:
            return self._private_state(0.0), self._public_state()

        guess = _sanitize(action)

        # The target is always an acceptable guess, so a configured target word
        # that is missing from the built-in list can still be guessed.
        off_list = (
            self.enforce_wordlist
            and guess != self.target
            and guess not in self.base_word_list
        )

        if len(guess) != self.word_length or off_list:
            # Penalize the invalid action; it costs a turn only when configured to
            self.invalid_component.invalid_attempted = True
            if self.consume_invalid_attempts:
                if self.remaining_guesses > 0:
                    self.remaining_guesses -= 1
                if self.remaining_guesses == 0:
                    self.status = "lost"
                    self.terminated = True
        else:
            self.guesses.append(guess)
            self.feedback.append(_score_guess(guess, self.target))
            self.remaining_guesses -= 1

            if guess == self.target:
                self.status = "won"
                self.terminated = True
            elif self.remaining_guesses == 0:
                self.status = "lost"
                self.terminated = True
            else:
                self.status = "in_progress"

        pub = self._public_state()
        reward = await self.reward_stack.step_reward(pub, action)
        self.total_reward += reward
        return self._private_state(reward), pub

    async def _serialize_engine(self) -> WordleEngineSnapshot:
        """Capture configuration and runtime state in a snapshot.

        List fields are copied so that later steps on this engine do not alter
        a snapshot that has already been taken.
        """
        return WordleEngineSnapshot(
            task_instance_dict=await self.task_instance.serialize(),
            engine_snapshot={
                "word_length": self.word_length,
                "max_guesses": self.max_guesses,
                "enforce_wordlist": self.enforce_wordlist,
                "consume_invalid_attempts": self.consume_invalid_attempts,
                "base_word_list": list(self.base_word_list),
                "fixed_target": self.fixed_target,
                "seed": self.seed,
                "target": self.target,
                "guesses": list(self.guesses),
                "feedback": list(self.feedback),
                "remaining_guesses": self.remaining_guesses,
                "status": self.status,
                "terminated": self.terminated,
                "total_reward": self.total_reward,
            },
        )

    @classmethod
    async def _deserialize_engine(cls, snapshot: WordleEngineSnapshot) -> "WordleEngine":
        """Rebuild an engine from *snapshot*.

        Optional snapshot keys fall back to the defaults of a freshly built
        engine. List fields are copied so that stepping the restored engine
        leaves the snapshot untouched.
        """
        task_instance = await TaskInstance.deserialize(snapshot.task_instance_dict)
        engine = cls(task_instance)
        s = snapshot.engine_snapshot
        engine.word_length = s["word_length"]
        engine.max_guesses = s["max_guesses"]
        engine.enforce_wordlist = s["enforce_wordlist"]
        engine.consume_invalid_attempts = s.get("consume_invalid_attempts", True)
        engine.base_word_list = list(s.get("base_word_list", engine.base_word_list))
        engine.fixed_target = s.get("fixed_target")
        engine.seed = s.get("seed")
        engine.target = s.get("target")
        engine.guesses = list(s.get("guesses", []))
        engine.feedback = list(s.get("feedback", []))
        engine.remaining_guesses = s.get("remaining_guesses", engine.max_guesses)
        engine.status = s.get("status", "in_progress")
        engine.terminated = s.get("terminated", False)
        engine.total_reward = s.get("total_reward", 0.0)
        return engine

    def get_current_states_for_observation(self) -> Tuple[WordlePrivateState, WordlePublicState]:
        """Return ``(private, public)`` for the current state without stepping.

        ``reward_last`` is reported as 0.0 because no step is taken.
        """
        return self._private_state(0.0), self._public_state()
|
349
|
+
|
350
|
+
|
351
|
+
class SynthWordleObservationCallable(GetObservationCallable):
    """Per-step observation: a text rendering of the game plus raw state fields."""

    async def get_observation(
        self, pub: WordlePublicState, priv: WordlePrivateState
    ) -> InternalObservation:
        # Closing line depends on whether the game is running, won or lost.
        if pub.status == "in_progress":
            footer = f"You have {pub.remaining_guesses} guesses left."
        elif pub.status == "won":
            footer = "You guessed the word! ✅"
        else:
            footer = "Out of guesses. ❌"

        text = "\n".join(
            (
                f"WORDLE ({pub.word_length} letters, {pub.max_guesses} max guesses)",
                "Submit a single English word (letters only).",
                "",
                pub.board_text,
                "",
                footer,
            )
        )

        return {
            "text": text,
            "status": pub.status,
            "remaining_guesses": pub.remaining_guesses,
            "guesses": pub.guesses,
            "feedback": pub.feedback,
            "reward_last": priv.reward_last,
            "total_reward": priv.total_reward,
            "terminated": pub.terminated,
        }
|
380
|
+
|
381
|
+
|
382
|
+
class SynthWordleCheckpointObservationCallable(GetObservationCallable):
    """Checkpoint observation: final board text, status, total reward and termination flag."""

    async def get_observation(
        self, pub: WordlePublicState, priv: WordlePrivateState
    ) -> InternalObservation:
        summary: InternalObservation = dict(
            board_text_final=pub.board_text,
            status_final=pub.status,
            total_reward=priv.total_reward,
            terminated=pub.terminated,
        )
        return summary
|
@@ -0,0 +1,154 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Optional, Dict, Any, Union
|
4
|
+
from pydantic import BaseModel, Field
|
5
|
+
|
6
|
+
from synth_ai.environments.stateful.core import StatefulEnvironment
|
7
|
+
from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
|
8
|
+
from synth_ai.environments.environment.shared_engine import (
|
9
|
+
GetObservationCallable,
|
10
|
+
InternalObservation,
|
11
|
+
)
|
12
|
+
from synth_ai.environments.environment.tools import (
|
13
|
+
AbstractTool,
|
14
|
+
EnvToolCall,
|
15
|
+
ToolResult,
|
16
|
+
)
|
17
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
18
|
+
|
19
|
+
from .engine import (
|
20
|
+
WordleEngine,
|
21
|
+
WordlePublicState,
|
22
|
+
WordlePrivateState,
|
23
|
+
WordleEngineSnapshot,
|
24
|
+
SynthWordleObservationCallable,
|
25
|
+
SynthWordleCheckpointObservationCallable,
|
26
|
+
)
|
27
|
+
|
28
|
+
|
29
|
+
# Argument schema for the "interact" tool: a single required `guess` string.
class WordleActionInput(BaseModel):
    guess: str = Field(
        ...,
        description="Your word guess (letters only)",
    )
|
31
|
+
|
32
|
+
|
33
|
+
class WordleInteractTool(AbstractTool):
    """Tool that forwards a validated guess to the wrapped Wordle engine."""

    name = "interact"
    description = "Submit a word guess to the Wordle environment."
    call_schema = WordleActionInput
    result_schema = ToolResult

    def __init__(self, engine: WordleEngine):
        self.engine = engine

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        """Validate the call arguments, step the engine and wrap the outcome.

        Any exception is turned into a failed ``ToolResult`` that carries the
        error text together with the engine's current states.
        """
        try:
            guess = self.call_schema(**call.args).guess
            private_state, public_state = await self.engine._step_engine(guess)
            return ToolResult(
                ok=True,
                payload={"public_state": public_state, "private_state": private_state},
            )
        except Exception as exc:
            # Report the failure alongside the state the engine is currently in
            private_state, public_state = self.engine.get_current_states_for_observation()
            return ToolResult(
                ok=False,
                error=str(exc),
                payload={"public_state": public_state, "private_state": private_state},
            )
|
51
|
+
|
52
|
+
|
53
|
+
class WordleEnvironment(StatefulEnvironment, ReproducibleEnvironment[WordleEngine]):
    """Environment wrapper exposing a ``WordleEngine`` through one "interact" tool."""

    def __init__(
        self,
        task_instance: TaskInstance,
        custom_step_obs: Optional[GetObservationCallable] = None,
        custom_ckpt_obs: Optional[GetObservationCallable] = None,
    ) -> None:
        """Create the engine and tool for *task_instance*.

        Args:
            task_instance: Task whose metadata configures the engine.
            custom_step_obs: Observation builder for ``initialize``/``step``;
                defaults to ``SynthWordleObservationCallable``.
            custom_ckpt_obs: Observation builder for ``checkpoint``/``terminate``;
                defaults to ``SynthWordleCheckpointObservationCallable``.
        """
        self.name = "Wordle"
        self.task_instance = task_instance
        self.custom_step_observation_callable = custom_step_obs or SynthWordleObservationCallable()
        self.custom_checkpoint_observation_callable = (
            custom_ckpt_obs or SynthWordleCheckpointObservationCallable()
        )
        self.engine = WordleEngine(task_instance)
        self._interact_tool = WordleInteractTool(self.engine)

    async def initialize(self) -> InternalObservation:
        """Reset the engine and return the first step observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._to_observation(priv, pub, self.custom_step_observation_callable)

    async def step(self, tool_calls) -> InternalObservation:
        """Run one tool call and return the resulting step observation.

        When the tool reports a failure, the observation is built from the
        engine's current state and carries the message under ``"error"``.
        """
        validated_call = self.validate_tool_calls(tool_calls)
        result = await self._interact_tool(validated_call)
        if result.ok:
            priv = result.payload["private_state"]
            pub = result.payload["public_state"]
            return await self._to_observation(priv, pub, self.custom_step_observation_callable)
        else:
            priv, pub = self.engine.get_current_states_for_observation()
            return await self._to_observation(
                priv, pub, self.custom_step_observation_callable, extra_obs={"error": result.error}
            )

    async def checkpoint(self) -> InternalObservation:
        """Return a checkpoint observation of the current state."""
        priv, pub = self.engine.get_current_states_for_observation()
        return await self._to_observation(priv, pub, self.custom_checkpoint_observation_callable)

    async def terminate(self) -> InternalObservation:
        """Return a checkpoint observation with both states marked terminated.

        Only the returned state objects are flagged; the engine's own fields
        are not modified here.
        """
        priv, pub = self.engine.get_current_states_for_observation()
        pub.terminated = True
        priv.terminated = True
        return await self._to_observation(priv, pub, self.custom_checkpoint_observation_callable)

    @staticmethod
    def _parse_args(raw_args):
        """Decode tool arguments that arrive as a JSON string; pass others through.

        Raises:
            ValueError: if a string is not valid JSON or is not a JSON object.
        """
        if not isinstance(raw_args, str):
            return raw_args

        import json  # local import: this module has no top-level json import

        try:
            parsed = json.loads(raw_args)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Tool call arguments are not valid JSON: {exc}") from exc
        if not isinstance(parsed, dict):
            raise ValueError("Tool call arguments must decode to a JSON object")
        return parsed

    def validate_tool_calls(self, tool_calls) -> EnvToolCall:
        """Normalize the accepted call shapes into one "interact" ``EnvToolCall``.

        Accepted inputs: an ``EnvToolCall``; a dict in ``{"tool", "args"}``,
        ``{"name", "parameters"}`` or ``{"function": {"name", "arguments"}}``
        form; a bare dict of arguments; a non-empty list (only its first item
        is used); or any other value, which is treated as the guess itself.
        Arguments given as a JSON string are decoded to a dict first.

        Raises:
            ValueError: on an empty list, a tool other than "interact", or
                string arguments that are not a JSON object.
        """
        if isinstance(tool_calls, EnvToolCall):
            validated = tool_calls
        elif isinstance(tool_calls, dict):
            if "tool" in tool_calls:
                validated = EnvToolCall(
                    tool=tool_calls["tool"],
                    args=self._parse_args(tool_calls.get("args", {})),
                )
            elif "name" in tool_calls:
                validated = EnvToolCall(
                    tool=tool_calls["name"],
                    args=self._parse_args(tool_calls.get("parameters", {})),
                )
            elif "function" in tool_calls:
                function = tool_calls["function"]
                # Function-call payloads commonly carry arguments as a JSON string
                validated = EnvToolCall(
                    tool=function["name"],
                    args=self._parse_args(function.get("arguments", {})),
                )
            else:
                # Treat remaining keys as args; default tool name
                validated = EnvToolCall(tool="interact", args=tool_calls)
        elif isinstance(tool_calls, list):
            if len(tool_calls) == 0:
                raise ValueError("Empty tool calls list")
            validated = self.validate_tool_calls(tool_calls[0])
        else:
            # Assume it's a raw guess string
            validated = EnvToolCall(tool="interact", args={"guess": str(tool_calls)})

        if validated.tool != "interact":
            raise ValueError(f"Unknown tool: {validated.tool}")
        # Normalize: allow 'action' key synonymous with 'guess'
        args = validated.args
        if "action" in args and "guess" not in args:
            args = {"guess": args["action"]}
        return EnvToolCall(tool="interact", args=args)

    async def _to_observation(
        self,
        priv: WordlePrivateState,
        pub: WordlePublicState,
        obs_cb: Optional[GetObservationCallable],
        extra_obs: Optional[Dict[str, Any]] = None,
    ) -> InternalObservation:
        """Build an observation with *obs_cb* and merge *extra_obs* into it.

        Without a callback the observation starts out empty. ``extra_obs`` is
        merged only when the observation is a dict.
        """
        if obs_cb:
            obs = await obs_cb.get_observation(pub, priv)
        else:
            obs: InternalObservation = {}
        if extra_obs and isinstance(obs, dict):
            obs.update(extra_obs)
        return obs

    async def _serialize_engine(self) -> WordleEngineSnapshot:
        """Delegate snapshot creation to the engine."""
        return await self.engine._serialize_engine()

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: WordleEngineSnapshot, task_instance: TaskInstance
    ) -> "WordleEnvironment":
        """Build an environment whose engine is restored from *snapshot*.

        The tool is re-created so that it points at the restored engine rather
        than the one made by the constructor.
        """
        env = cls(task_instance)
        env.engine = await WordleEngine._deserialize_engine(snapshot)
        env._interact_tool = WordleInteractTool(env.engine)
        return env
|
154
|
+
|
@@ -0,0 +1,75 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Generate a fixed Wordle instances.json using the "wordfreq" package.
|
4
|
+
|
5
|
+
Usage:
|
6
|
+
pip install wordfreq
|
7
|
+
python -m synth_ai.environments.examples.wordle.helpers.generate_instances_wordfreq \
|
8
|
+
--count 500 --min-zipf 3.0 --outfile synth_ai/environments/examples/wordle/instances.json
|
9
|
+
|
10
|
+
This script writes a deterministic list of 5-letter English words ranked by frequency.
|
11
|
+
Commit the resulting instances.json to remove runtime dependencies.
|
12
|
+
"""
|
13
|
+
from __future__ import annotations
|
14
|
+
|
15
|
+
import argparse
|
16
|
+
import json
|
17
|
+
import re
|
18
|
+
from typing import List
|
19
|
+
|
20
|
+
from wordfreq import zipf_frequency, top_n_list
|
21
|
+
|
22
|
+
|
23
|
+
def build_word_list(count: int, length: int, min_zipf: float, wordlist: str = "large") -> List[str]:
    """Return the *count* most frequent English words of exactly *length* letters.

    Candidates come from the top of the wordfreq list named by *wordlist*,
    are restricted to lowercase a-z, and must have a Zipf frequency of at
    least *min_zipf*. The result is ordered by descending frequency, with ties
    broken alphabetically.

    Args:
        count: Number of words wanted; zero or less yields an empty list.
        length: Required word length.
        min_zipf: Minimum Zipf frequency a word must reach.
        wordlist: wordfreq word list used both to pick and to score candidates.

    Raises:
        RuntimeError: if fewer than *count* words survive the filters.
    """
    if count <= 0:
        return []

    # Oversample the frequency list: only a fraction of entries have the right length.
    pool_size = max(count * 20, 5000)
    letters_only = re.compile(r"[a-z]+")
    threshold = float(min_zipf)

    scores = {}
    for raw in top_n_list("en", pool_size, wordlist=wordlist):
        word = raw.lower()
        if word in scores or len(word) != length or not letters_only.fullmatch(word):
            continue
        # Score against the same list the candidates were drawn from.
        zipf = zipf_frequency(word, "en", wordlist=wordlist)
        if zipf >= threshold:
            scores[word] = zipf

    ranked = sorted(scores, key=lambda w: (-scores[w], w))
    if len(ranked) < count:
        raise RuntimeError(
            f"Insufficient {length}-letter words from wordfreq after filtering ({len(ranked)} < {count})."
        )
    return ranked[:count]
|
44
|
+
|
45
|
+
|
46
|
+
def main():
    """Parse command-line options, build the word list and write it out as JSON."""
    parser = argparse.ArgumentParser()
    optional_flags = (
        ("--count", int, 500),
        ("--length", int, 5),
        ("--min-zipf", float, 3.0),
        ("--wordlist", str, "large"),
    )
    for flag, kind, fallback in optional_flags:
        parser.add_argument(flag, type=kind, default=fallback)
    parser.add_argument("--outfile", type=str, required=True)
    opts = parser.parse_args()

    words = build_word_list(opts.count, opts.length, opts.min_zipf, opts.wordlist)

    # Settings shared by every instance in the generated task set
    defaults = {
        "word_length": opts.length,
        "max_guesses": 6,
        "enforce_wordlist": True,
        "consume_invalid_attempts": True,
    }
    instances = [{"target_word": word} for word in words]
    data = {
        "name": f"Wordle Fixed TaskSet ({opts.count} English words)",
        "description": f"{len(words)} {opts.length}-letter English words ranked by frequency (wordfreq).",
        "defaults": defaults,
        "instances": instances,
    }

    with open(opts.outfile, "w") as handle:
        json.dump(data, handle, indent=2)
    print(f"Wrote {len(words)} words to {opts.outfile}")
|
72
|
+
|
73
|
+
|
74
|
+
if __name__ == "__main__":
|
75
|
+
main()
|