synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +575 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +104 -6
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,368 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from typing import Dict, Any, Optional, Tuple
|
6
|
+
|
7
|
+
from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
|
8
|
+
from synth_ai.environments.reproducibility.core import IReproducibleEngine
|
9
|
+
from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent
|
10
|
+
from synth_ai.environments.environment.shared_engine import (
|
11
|
+
GetObservationCallable,
|
12
|
+
InternalObservation,
|
13
|
+
)
|
14
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
15
|
+
|
16
|
+
|
17
|
+
# Action mapping: coordinate strings to board indices.
# The letter selects a group of three consecutive indices and the digit the
# offset inside that group ("A1" -> 0, "A2" -> 1, ..., "C3" -> 8).
COORD_TO_IDX = {
    "A1": 0,
    "A2": 1,
    "A3": 2,
    "B1": 3,
    "B2": 4,
    "B3": 5,
    "C1": 6,
    "C2": 7,
    "C3": 8,
}
IDX_TO_COORD = {v: k for k, v in COORD_TO_IDX.items()}

# Win condition patterns (row, col, diagonal indices)
WIN_PATTERNS = [
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8],  # rows
    [0, 3, 6],
    [1, 4, 7],
    [2, 5, 8],  # columns
    [0, 4, 8],
    [2, 4, 6],  # diagonals
]

# Player mappings
PLAYER_MARKS = {"X": 1, "O": 2}
MARK_TO_PLAYER = {1: "X", 2: "O", 0: " "}


@dataclass
class TicTacToePublicState:
    """Agent-visible view of a TicTacToe game at one point in time."""

    board: np.ndarray  # flat length-9 array indexed via COORD_TO_IDX: 0=empty, 1=X, 2=O
    current_player: str  # "X" or "O"
    last_move: Optional[str]  # "A1", "B2", etc.
    winner: Optional[str]  # None, "X", "O", or "draw"
    move_count: int  # Number of moves made
    max_moves: int  # Always 9 for TicTacToe
    terminated: bool  # Game finished

    def diff(self, prev_state: "TicTacToePublicState") -> Dict[str, Any]:
        """Return the fields that changed relative to ``prev_state``.

        Args:
            prev_state: The earlier state to compare against.

        Returns:
            A dict mapping each changed field name to its value in ``self``.
            The board is reported as a plain list. ``max_moves`` is never
            compared.
        """
        differences: Dict[str, Any] = {}
        if not np.array_equal(self.board, prev_state.board):
            differences["board"] = self.board.tolist()
        for name in ("current_player", "last_move", "winner", "move_count", "terminated"):
            current = getattr(self, name)
            if current != getattr(prev_state, name):
                differences[name] = current
        return differences

    @property
    def board_text(self) -> str:
        """Render the board as text with letter columns and number rows.

        Each cell is looked up through ``COORD_TO_IDX`` so that the mark drawn
        under column ``L`` on row ``n`` is the one stored for coordinate
        ``"Ln"``; the labels therefore match the coordinates used for moves.
        """
        letters = ("A", "B", "C")
        # Two-character gutter so the letters line up with the "<n> " row prefix.
        lines = ["  " + " ".join(letters)]
        for number in (1, 2, 3):
            row_marks = [
                MARK_TO_PLAYER[self.board[COORD_TO_IDX[f"{letter}{number}"]]]
                for letter in letters
            ]
            lines.append(f"{number} {' '.join(row_marks)}")
        return "\n".join(lines)
|
85
|
+
|
86
|
+
|
87
|
+
@dataclass
class TicTacToePrivateState:
    """Reward and episode-status bookkeeping that accompanies a public state."""

    reward_last: float
    total_reward: float
    terminated: bool
    truncated: bool

    def diff(self, prev_state: "TicTacToePrivateState") -> Dict[str, Any]:
        """Return the fields whose values differ from ``prev_state``.

        Args:
            prev_state: The earlier private state to compare against.

        Returns:
            A dict of changed field names to their value in ``self``, in field
            declaration order; empty when nothing changed.
        """
        tracked = ("reward_last", "total_reward", "terminated", "truncated")
        return {
            field_name: getattr(self, field_name)
            for field_name in tracked
            if getattr(self, field_name) != getattr(prev_state, field_name)
        }
|
105
|
+
|
106
|
+
|
107
|
+
@dataclass
class TicTacToeEngineSnapshot(StatefulEngineSnapshot):
    """Serialized form of a TicTacToe engine together with its task instance."""

    # Output of ``task_instance.serialize()``.
    task_instance_dict: Dict
    # Engine fields: board (as a list), current_player, last_move, winner,
    # move_count, terminated and total_reward.
    engine_snapshot: Dict
|
111
|
+
|
112
|
+
|
113
|
+
class TicTacToeWinComponent(RewardComponent):
    """Reward component that scores a decided game for one player's mark."""

    def __init__(self, player_mark: str = "X"):
        """Remember which mark ("X" by default) counts as a win."""
        super().__init__()
        self.player_mark = player_mark

    async def score(self, state: TicTacToePublicState, action: Any) -> float:
        """Return 1.0 for a win, -1.0 for a loss, 0.0 otherwise.

        Args:
            state: Public state whose ``winner`` field is inspected.
            action: Unused.
        """
        outcome = state.winner
        if outcome == self.player_mark:
            return 1.0
        # No winner yet (falsy) and "draw" are both neutral.
        if not outcome or outcome == "draw":
            return 0.0
        # Any other winner value means the other side won.
        return -1.0
|
124
|
+
|
125
|
+
|
126
|
+
class TicTacToeDrawComponent(RewardComponent):
    """Reward component contributing a configurable score for drawn games.

    The score for a draw used to be a hard-coded 0.0 that was identical to the
    non-draw result, which made the conditional meaningless. It is now a
    constructor parameter whose default keeps the previous behaviour.
    """

    def __init__(self, draw_reward: float = 0.0):
        """Store the reward returned when the game ends in a draw.

        Args:
            draw_reward: Value returned by ``score`` for a drawn game.
        """
        super().__init__()
        self.draw_reward = draw_reward

    async def score(self, state: TicTacToePublicState, action: Any) -> float:
        """Return ``draw_reward`` when ``state.winner`` is "draw", else 0.0.

        Args:
            state: Public state whose ``winner`` field is inspected.
            action: Unused.
        """
        if state.winner == "draw":
            return self.draw_reward
        return 0.0
|
131
|
+
|
132
|
+
|
133
|
+
class TicTacToeIllegalMoveComponent(RewardComponent):
    """Reward component that penalises an illegal move exactly once.

    The owner sets ``illegal_move_attempted`` to True when a move is rejected;
    the next ``score`` call returns the penalty and clears the flag.
    """

    def __init__(self):
        # Initialise the base class first, as TicTacToeWinComponent does.
        super().__init__()
        self.illegal_move_attempted = False

    async def score(self, state: TicTacToePublicState, action: Any) -> float:
        """Return -1.0 if an illegal move was flagged since the last call.

        Args:
            state: Unused.
            action: Unused.

        Returns:
            -1.0 when the flag was set (the flag is then cleared), else 0.0.
        """
        if self.illegal_move_attempted:
            # Consume the flag so the penalty is applied only once.
            self.illegal_move_attempted = False
            return -1.0
        return 0.0
|
142
|
+
|
143
|
+
|
144
|
+
class TicTacToeEngine(StatefulEngine, IReproducibleEngine):
    """TicTacToe game engine: board, turn order, outcome and reward totals.

    The board is a flat length-9 integer array addressed through
    ``COORD_TO_IDX``. "X" always moves first. Moves listed in
    ``task_instance.metadata.opening_moves`` (when present) are applied both on
    construction and on every reset.
    """

    def __init__(self, task_instance: TaskInstance):
        """Create the engine and start a fresh game.

        Args:
            task_instance: Task description. Its ``metadata.starting_player``
                (if present) selects the mark scored as a win, and its
                ``metadata.opening_moves`` (if present) are pre-applied.
        """
        self.task_instance = task_instance
        self.illegal_move_component = TicTacToeIllegalMoveComponent()

        # Determine which player the agent is controlling
        agent_player = "X"  # Default to X
        if hasattr(task_instance, "metadata") and hasattr(
            task_instance.metadata, "starting_player"
        ):
            agent_player = task_instance.metadata.starting_player

        self.reward_stack = RewardStack(
            [
                TicTacToeWinComponent(player_mark=agent_player),
                TicTacToeDrawComponent(),
                self.illegal_move_component,
            ]
        )

        self._start_new_game()

    def _start_new_game(self) -> None:
        """Reset all game fields and replay the task's opening moves."""
        self.board = np.zeros(9, dtype=int)
        self.current_player = "X"
        self.last_move = None
        self.winner = None
        self.move_count = 0
        self.terminated = False
        self.total_reward = 0.0

        # Apply any pre-moves from task instance metadata
        metadata = getattr(self.task_instance, "metadata", None)
        for move in getattr(metadata, "opening_moves", None) or ():
            self._apply_move(move)

    def _public_state(self) -> TicTacToePublicState:
        """Build a public state from the current fields (board is copied)."""
        return TicTacToePublicState(
            board=self.board.copy(),
            current_player=self.current_player,
            last_move=self.last_move,
            winner=self.winner,
            move_count=self.move_count,
            max_moves=9,
            terminated=self.terminated,
        )

    def _private_state(self, reward_last: float) -> TicTacToePrivateState:
        """Build a private state carrying ``reward_last`` and the running total."""
        return TicTacToePrivateState(
            reward_last=reward_last,
            total_reward=self.total_reward,
            terminated=self.terminated,
            truncated=False,
        )

    async def _reset_engine(
        self, *, seed: int | None = None
    ) -> Tuple[TicTacToePrivateState, TicTacToePublicState]:
        """Start a new game and return its initial states.

        Args:
            seed: Accepted for interface compatibility; not used.

        Returns:
            ``(private_state, public_state)`` with ``reward_last`` of 0.0.
        """
        self._start_new_game()
        return self._private_state(reward_last=0.0), self._public_state()

    async def _step_engine(self, action: str) -> Tuple[TicTacToePrivateState, TicTacToePublicState]:
        """Play ``action`` for the current player and score the result.

        An invalid move (unknown coordinate or occupied cell) flags the
        illegal-move reward component and terminates the game. Once the game
        has terminated, further calls change nothing and yield a reward of 0.0;
        otherwise extra moves could be placed after the end and the win reward
        would be granted again on every call.

        Args:
            action: Coordinate string such as "A1".

        Returns:
            ``(private_state, public_state)`` after the step.
        """
        if self.terminated:
            return self._private_state(reward_last=0.0), self._public_state()

        # Validate and apply move
        if self._is_valid_move(action, self.board):
            self._apply_move(action)
        else:
            self.illegal_move_component.illegal_move_attempted = True
            self.terminated = True

        public_state = self._public_state()

        # Calculate rewards
        reward = await self.reward_stack.step_reward(public_state, action)
        self.total_reward += reward

        return self._private_state(reward_last=reward), public_state

    def _apply_move(self, coord: str):
        """Place the current player's mark at ``coord`` and update the outcome.

        Unknown coordinates and occupied cells are ignored without any state
        change. After a successful placement the winner is re-evaluated; the
        game terminates on a decided outcome or after nine moves, otherwise
        the turn passes to the other player.
        """
        if coord not in COORD_TO_IDX:
            return

        idx = COORD_TO_IDX[coord]
        if self.board[idx] != 0:
            return

        self.board[idx] = PLAYER_MARKS[self.current_player]
        self.last_move = coord
        self.move_count += 1

        # Check for winner
        self.winner = self._check_winner(self.board)

        # Check if game is over
        if self.winner is not None or self.move_count >= 9:
            self.terminated = True
        else:
            # Switch players
            self.current_player = "O" if self.current_player == "X" else "X"

    def _check_winner(self, board: np.ndarray) -> Optional[str]:
        """Return "X" or "O" for a completed line, "draw" for a full board, else None."""
        # Check all win patterns
        for first, second, third in WIN_PATTERNS:
            mark = board[first]
            if mark != 0 and mark == board[second] == board[third]:
                return MARK_TO_PLAYER[mark]

        # Check for draw
        if np.all(board != 0):
            return "draw"

        return None

    def _is_valid_move(self, coord: str, board: np.ndarray) -> bool:
        """Return True when ``coord`` is a known coordinate whose cell is empty."""
        if coord not in COORD_TO_IDX:
            return False
        return bool(board[COORD_TO_IDX[coord]] == 0)

    async def _serialize_engine(self) -> TicTacToeEngineSnapshot:
        """Capture the task instance and all game fields in a snapshot."""
        return TicTacToeEngineSnapshot(
            task_instance_dict=await self.task_instance.serialize(),
            engine_snapshot={
                "board": self.board.tolist(),
                "current_player": self.current_player,
                "last_move": self.last_move,
                "winner": self.winner,
                "move_count": self.move_count,
                "terminated": self.terminated,
                "total_reward": self.total_reward,
            },
        )

    @classmethod
    async def _deserialize_engine(cls, snapshot: TicTacToeEngineSnapshot) -> "TicTacToeEngine":
        """Rebuild an engine from ``snapshot``.

        The engine is constructed normally (which replays any opening moves)
        and every game field is then overwritten from the snapshot.
        """
        task_instance = await TaskInstance.deserialize(snapshot.task_instance_dict)
        engine = cls(task_instance)

        # Restore state
        saved = snapshot.engine_snapshot
        # Explicit dtype keeps the restored board identical in type to a fresh one.
        engine.board = np.array(saved["board"], dtype=int)
        engine.current_player = saved["current_player"]
        engine.last_move = saved["last_move"]
        engine.winner = saved["winner"]
        engine.move_count = saved["move_count"]
        engine.terminated = saved["terminated"]
        engine.total_reward = saved["total_reward"]

        return engine

    def get_current_states_for_observation(
        self,
    ) -> Tuple[TicTacToePrivateState, TicTacToePublicState]:
        """Return fresh ``(private_state, public_state)`` objects for the current game.

        ``reward_last`` is reported as 0.0 because no step is being taken.
        """
        return self._private_state(reward_last=0.0), self._public_state()
|
338
|
+
|
339
|
+
|
340
|
+
class SynthTicTacToeObservationCallable(GetObservationCallable):
    """Builds the per-step observation from the public and private states."""

    async def get_observation(
        self, pub: TicTacToePublicState, priv: TicTacToePrivateState
    ) -> InternalObservation:
        """Return a dict with the rendered board, game status and rewards."""
        game_fields = {
            "board_text": pub.board_text,
            "current_player": pub.current_player,
            "move_count": pub.move_count,
            "last_move": pub.last_move,
            "winner": pub.winner,
            "terminated": pub.terminated,
        }
        reward_fields = {
            "reward_last": priv.reward_last,
            "total_reward": priv.total_reward,
        }
        observation: InternalObservation = {**game_fields, **reward_fields}
        return observation
|
355
|
+
|
356
|
+
|
357
|
+
class SynthTicTacToeCheckpointObservationCallable(GetObservationCallable):
    """Builds the summary observation used for checkpoints and termination."""

    async def get_observation(
        self, pub: TicTacToePublicState, priv: TicTacToePrivateState
    ) -> InternalObservation:
        """Return a dict with the final board, outcome, move count and total reward."""
        summary_items = [
            ("board_text_final", pub.board_text),
            ("winner_final", pub.winner),
            ("move_count_final", pub.move_count),
            ("total_reward", priv.total_reward),
            ("terminated", pub.terminated),
        ]
        observation: InternalObservation = dict(summary_items)
        return observation
|
@@ -0,0 +1,239 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Dict, Optional, Any, List, Union
|
4
|
+
from pydantic import BaseModel
|
5
|
+
|
6
|
+
from synth_ai.environments.stateful.core import StatefulEnvironment
|
7
|
+
from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
|
8
|
+
from synth_ai.environments.environment.shared_engine import (
|
9
|
+
GetObservationCallable,
|
10
|
+
InternalObservation,
|
11
|
+
)
|
12
|
+
from synth_ai.environments.environment.tools import (
|
13
|
+
AbstractTool,
|
14
|
+
EnvToolCall,
|
15
|
+
ToolResult,
|
16
|
+
)
|
17
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
18
|
+
|
19
|
+
from .engine import (
|
20
|
+
TicTacToeEngine,
|
21
|
+
TicTacToePublicState,
|
22
|
+
TicTacToePrivateState,
|
23
|
+
TicTacToeEngineSnapshot,
|
24
|
+
SynthTicTacToeObservationCallable,
|
25
|
+
SynthTicTacToeCheckpointObservationCallable,
|
26
|
+
)
|
27
|
+
|
28
|
+
|
29
|
+
class TicTacToeActionInput(BaseModel):
    """Argument schema for the "interact" tool: one board cell as letter + number."""

    letter: str  # "A", "B", or "C"
    number: int  # 1, 2, or 3
|
32
|
+
|
33
|
+
|
34
|
+
class TicTacToeInteractTool(AbstractTool):
    """Tool that plays a single move on the engine it wraps."""

    name = "interact"
    description = "Place your mark (X or O) in the specified cell using letter (A, B, C) and number (1, 2, 3) coordinates."
    call_schema = TicTacToeActionInput
    result_schema = ToolResult

    def __init__(self, engine: TicTacToeEngine):
        """Bind the tool to ``engine``."""
        self.engine = engine

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        """Validate ``call.args`` and step the engine.

        Args:
            call: Tool call whose ``args`` carry ``letter`` and ``number``.

        Returns:
            On success, a result whose payload holds ``public_state`` and
            ``private_state``. On a validation failure or any exception, a
            result with ``ok=False``, an error message and an empty payload.
        """
        try:
            letter = call.args.get("letter")
            number = call.args.get("number")

            # Find the first validation problem, if any.
            problem = None
            if not letter or number is None:
                problem = "Both letter and number parameters are required"
            elif letter not in ("A", "B", "C"):
                problem = f"Letter must be A, B, or C, got '{letter}'"
            elif number not in (1, 2, 3):
                problem = f"Number must be 1, 2, or 3, got {number}"

            if problem is not None:
                return ToolResult(ok=False, error=problem, payload={})

            # The engine takes a coordinate string (e.g., "A1", "B2", etc.)
            coordinate = f"{letter}{number}"
            private_state, public_state = await self.engine._step_engine(coordinate)
            outcome = {"public_state": public_state, "private_state": private_state}
            return ToolResult(ok=True, payload=outcome)
        except Exception as e:
            return ToolResult(ok=False, error=str(e), payload={})
|
78
|
+
|
79
|
+
|
80
|
+
class TicTacToeEnvironment(StatefulEnvironment, ReproducibleEnvironment[TicTacToeEngine]):
    """Environment wrapper exposing a TicTacToe engine through an "interact" tool."""

    def __init__(
        self,
        task_instance: TaskInstance,
        custom_step_obs: Optional[GetObservationCallable] = None,
        custom_ckpt_obs: Optional[GetObservationCallable] = None,
    ):
        """Create the engine and tool for ``task_instance``.

        Args:
            task_instance: Task description handed to the engine.
            custom_step_obs: Observation builder for steps; defaults to
                ``SynthTicTacToeObservationCallable``.
            custom_ckpt_obs: Observation builder for checkpoints and
                termination; defaults to
                ``SynthTicTacToeCheckpointObservationCallable``.
        """
        self.name = "TicTacToe"
        self.task_instance = task_instance
        self.custom_step_observation_callable = (
            custom_step_obs or SynthTicTacToeObservationCallable()
        )
        self.custom_checkpoint_observation_callable = (
            custom_ckpt_obs or SynthTicTacToeCheckpointObservationCallable()
        )
        self.engine = TicTacToeEngine(task_instance)
        self._interact_tool = TicTacToeInteractTool(self.engine)

    async def initialize(self) -> InternalObservation:
        """Reset the engine and return the initial step observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._to_observation(priv, pub, self.custom_step_observation_callable)

    async def step(self, tool_calls) -> InternalObservation:
        """Normalise ``tool_calls``, play the move and return a step observation.

        When the tool reports a failure, the observation is built from the
        engine's current state and carries the message under ``"error"``.

        Raises:
            ValueError: Propagated from ``validate_tool_calls``.
        """
        validated_call = self.validate_tool_calls(tool_calls)

        # Execute the interact tool
        result = await self._interact_tool(validated_call)

        if result.ok:
            priv = result.payload["private_state"]
            pub = result.payload["public_state"]
            return await self._to_observation(priv, pub, self.custom_step_observation_callable)

        # Return error observation
        priv, pub = self.engine.get_current_states_for_observation()
        return await self._to_observation(
            priv,
            pub,
            self.custom_step_observation_callable,
            extra_obs={"error": result.error},
        )

    async def checkpoint(self) -> InternalObservation:
        """Return the checkpoint observation for the current engine state."""
        priv, pub = self.engine.get_current_states_for_observation()
        return await self._to_observation(priv, pub, self.custom_checkpoint_observation_callable)

    async def terminate(self) -> InternalObservation:
        """Return a final observation whose states are marked terminated.

        Only the state objects returned for this observation are flagged; the
        engine's own ``terminated`` attribute is not modified here.
        """
        priv, pub = self.engine.get_current_states_for_observation()
        pub.terminated = True
        priv.terminated = True
        return await self._to_observation(priv, pub, self.custom_checkpoint_observation_callable)

    @staticmethod
    def _decode_call_args(raw_args: Any) -> Any:
        """Return function-call arguments as a dict, decoding JSON text if needed.

        OpenAI-style function calls carry ``arguments`` as a JSON string, so a
        string is parsed here; anything else is returned unchanged.

        Raises:
            ValueError: If the string is not valid JSON or does not decode to
                a JSON object.
        """
        if raw_args is None:
            return {}
        if not isinstance(raw_args, str):
            return raw_args
        if not raw_args.strip():
            return {}

        import json  # local import: only needed for string-encoded arguments

        try:
            decoded = json.loads(raw_args)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Function arguments are not valid JSON: {raw_args!r}") from exc
        if not isinstance(decoded, dict):
            raise ValueError(f"Function arguments must decode to an object, got {raw_args!r}")
        return decoded

    def validate_tool_calls(self, tool_calls) -> EnvToolCall:
        """Normalise the supported call formats into one "interact" call.

        Accepted inputs are an ``EnvToolCall``; a dict with ``tool``/``args``,
        ``name``/``parameters`` or ``function`` (OpenAI style) keys; a bare
        dict of arguments; a non-empty list (only the first element is used);
        or any other value, which is stringified and treated as an ``action``
        coordinate. Legacy ``position`` (0-8) and ``action`` ("A1") arguments
        are converted to ``letter``/``number``.

        Returns:
            An ``EnvToolCall`` for the "interact" tool.

        Raises:
            ValueError: For an empty list, an unknown tool name, undecodable
                function arguments, or out-of-range coordinates.
        """
        # Handle various input formats
        if isinstance(tool_calls, EnvToolCall):
            validated_call = tool_calls
        elif isinstance(tool_calls, dict):
            if "tool" in tool_calls:
                validated_call = EnvToolCall(
                    tool=tool_calls["tool"], args=tool_calls.get("args", {})
                )
            elif "name" in tool_calls:
                # Handle legacy format
                validated_call = EnvToolCall(
                    tool=tool_calls["name"], args=tool_calls.get("parameters", {})
                )
            elif "function" in tool_calls:
                # Handle OpenAI function call format; arguments may be JSON text.
                function_call = tool_calls["function"]
                validated_call = EnvToolCall(
                    tool=function_call["name"],
                    args=self._decode_call_args(function_call.get("arguments", {})),
                )
            else:
                # Assume it's just parameters
                validated_call = EnvToolCall(tool="interact", args=tool_calls)
        elif isinstance(tool_calls, list):
            # Take first call from list
            if not tool_calls:
                raise ValueError("Empty tool calls list")
            validated_call = self.validate_tool_calls(tool_calls[0])
        else:
            # Treat any other value as a coordinate string
            validated_call = EnvToolCall(tool="interact", args={"action": str(tool_calls)})

        # Validate tool name
        if validated_call.tool != "interact":
            raise ValueError(f"Unknown tool: {validated_call.tool}")

        # Convert legacy formats to new letter/number format
        args = validated_call.args
        if "position" in args:
            # Convert numeric position (0-8) to letter/number
            position = args["position"]
            if position < 0 or position > 8:
                raise ValueError(f"Position {position} must be between 0 and 8")
            letter = ["A", "B", "C"][position // 3]
            number = (position % 3) + 1
            args = {"letter": letter, "number": number}
        elif "action" in args:
            # Convert coordinate string (e.g., "A1") to letter/number
            action = args["action"]
            if len(action) != 2:
                raise ValueError(f"Action '{action}' must be 2 characters (e.g., 'A1')")
            letter = action[0].upper()
            try:
                number = int(action[1])
            except ValueError as exc:
                raise ValueError(
                    f"Action '{action}' must have a numeric second character"
                ) from exc
            args = {"letter": letter, "number": number}

        # Validate final letter/number values
        if "letter" in args and "number" in args:
            letter = args["letter"]
            number = args["number"]
            if letter not in ["A", "B", "C"]:
                raise ValueError(f"Letter must be A, B, or C, got '{letter}'")
            if number not in [1, 2, 3]:
                raise ValueError(f"Number must be 1, 2, or 3, got {number}")

        return EnvToolCall(tool=validated_call.tool, args=args)

    async def _to_observation(
        self,
        priv: TicTacToePrivateState,
        pub: TicTacToePublicState,
        obs_cb: Optional[GetObservationCallable],
        extra_obs: Optional[Dict] = None,
    ) -> InternalObservation:
        """Build an observation via ``obs_cb`` and merge in ``extra_obs``.

        Args:
            priv: Private state passed to the callback.
            pub: Public state passed to the callback.
            obs_cb: Observation builder; when falsy an empty dict is used.
            extra_obs: Extra entries merged in when the observation is a dict.
        """
        if obs_cb:
            obs = await obs_cb.get_observation(pub, priv)
        else:
            obs = {}

        if extra_obs and isinstance(obs, dict):
            obs.update(extra_obs)

        return obs

    async def _serialize_engine(self) -> TicTacToeEngineSnapshot:
        """Delegate serialization to the engine."""
        return await self.engine._serialize_engine()

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: TicTacToeEngineSnapshot, task_instance: TaskInstance
    ) -> "TicTacToeEnvironment":
        """Create an environment whose engine is restored from ``snapshot``.

        The interact tool is rebuilt so that it targets the restored engine
        rather than the one created by the constructor.
        """
        env = cls(task_instance)
        env.engine = await TicTacToeEngine._deserialize_engine(snapshot)
        env._interact_tool = TicTacToeInteractTool(env.engine)
        return env
|