synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. synth_ai/environments/examples/__init__.py +1 -0
  2. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  3. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  4. synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
  5. synth_ai/environments/examples/crafter_classic/engine.py +579 -0
  6. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  7. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  8. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  9. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
  10. synth_ai/environments/examples/crafter_classic/environment.py +364 -0
  11. synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
  12. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
  13. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
  14. synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
  15. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
  16. synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
  17. synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
  18. synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
  19. synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
  20. synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
  21. synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
  22. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
  23. synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
  24. synth_ai/environments/examples/crafter_custom/environment.py +312 -0
  25. synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
  26. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  27. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  28. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  29. synth_ai/environments/examples/enron/engine.py +291 -0
  30. synth_ai/environments/examples/enron/environment.py +165 -0
  31. synth_ai/environments/examples/enron/taskset.py +112 -0
  32. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  33. synth_ai/environments/examples/minigrid/engine.py +589 -0
  34. synth_ai/environments/examples/minigrid/environment.py +274 -0
  35. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  36. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  37. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  38. synth_ai/environments/examples/nethack/__init__.py +7 -0
  39. synth_ai/environments/examples/nethack/achievements.py +337 -0
  40. synth_ai/environments/examples/nethack/engine.py +738 -0
  41. synth_ai/environments/examples/nethack/environment.py +255 -0
  42. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  43. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  44. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  45. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  46. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  47. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  48. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  49. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  50. synth_ai/environments/examples/nethack/taskset.py +323 -0
  51. synth_ai/environments/examples/red/__init__.py +7 -0
  52. synth_ai/environments/examples/red/config_logging.py +110 -0
  53. synth_ai/environments/examples/red/engine.py +693 -0
  54. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  55. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  56. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  57. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  58. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  59. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  60. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  61. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  62. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  63. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  64. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  69. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  70. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  71. synth_ai/environments/examples/red/environment.py +235 -0
  72. synth_ai/environments/examples/red/taskset.py +77 -0
  73. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  74. synth_ai/environments/examples/sokoban/engine.py +675 -0
  75. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  76. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  77. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  78. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  79. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  80. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  81. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  82. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  87. synth_ai/environments/examples/sokoban/environment.py +228 -0
  88. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  89. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  90. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  91. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  92. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  93. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  94. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  95. synth_ai/environments/examples/verilog/__init__.py +10 -0
  96. synth_ai/environments/examples/verilog/engine.py +328 -0
  97. synth_ai/environments/examples/verilog/environment.py +349 -0
  98. synth_ai/environments/examples/verilog/taskset.py +418 -0
  99. synth_ai/environments/examples/wordle/__init__.py +29 -0
  100. synth_ai/environments/examples/wordle/engine.py +391 -0
  101. synth_ai/environments/examples/wordle/environment.py +154 -0
  102. synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
  103. synth_ai/environments/examples/wordle/taskset.py +222 -0
  104. synth_ai/environments/service/app.py +8 -0
  105. synth_ai/environments/service/core_routes.py +38 -0
  106. synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
  107. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
  108. synth_ai/learning/prompts/mipro.py +273 -1
  109. synth_ai/learning/prompts/random_search.py +247 -0
  110. synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
  111. synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
  112. synth_ai/lm/injection.py +81 -0
  113. synth_ai/lm/overrides.py +204 -0
  114. synth_ai/lm/provider_support/anthropic.py +39 -12
  115. synth_ai/lm/provider_support/openai.py +31 -4
  116. synth_ai/lm/vendors/core/anthropic_api.py +16 -0
  117. synth_ai/lm/vendors/openai_standard.py +35 -5
  118. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
  119. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
  120. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
  121. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
  122. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
  123. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,368 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ from dataclasses import dataclass
5
+ from typing import Dict, Any, Optional, Tuple
6
+
7
+ from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
8
+ from synth_ai.environments.reproducibility.core import IReproducibleEngine
9
+ from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent
10
+ from synth_ai.environments.environment.shared_engine import (
11
+ GetObservationCallable,
12
+ InternalObservation,
13
+ )
14
+ from synth_ai.environments.tasks.core import TaskInstance
15
+
16
+
17
# Action mapping: coordinate strings ("A1" .. "C3") to flat board indices 0-8.
# The letter selects a group of three consecutive indices (A -> 0-2,
# B -> 3-5, C -> 6-8) and the number selects the offset inside that group.
COORD_TO_IDX = {
    f"{letter}{number}": group * 3 + (number - 1)
    for group, letter in enumerate("ABC")
    for number in (1, 2, 3)
}
IDX_TO_COORD = {idx: coord for coord, idx in COORD_TO_IDX.items()}

# Winning lines, each given as a triple of flat board indices.
WIN_PATTERNS = (
    [[start, start + 1, start + 2] for start in (0, 3, 6)]  # rows
    + [[start, start + 3, start + 6] for start in (0, 1, 2)]  # columns
    + [[0, 4, 8], [2, 4, 6]]  # diagonals
)

# Player mappings: symbol -> board value, and board value -> display character.
PLAYER_MARKS = {"X": 1, "O": 2}
MARK_TO_PLAYER = {**{value: symbol for symbol, value in PLAYER_MARKS.items()}, 0: " "}
46
+
47
+
48
@dataclass
class TicTacToePublicState:
    """Agent-visible snapshot of a TicTacToe game.

    ``board`` holds one value per cell (0 = empty, 1 = X, 2 = O) in the flat
    index order defined by ``COORD_TO_IDX``. The engine stores it as a
    length-9 array; a 3x3 array is also accepted by ``board_text`` (it is
    flattened row-major).
    """

    board: np.ndarray  # cell values: 0=empty, 1=X, 2=O, indexed via COORD_TO_IDX
    current_player: str  # "X" or "O"
    last_move: Optional[str]  # "A1", "B2", etc.
    winner: Optional[str]  # None, "X", "O", or "draw"
    move_count: int  # Number of moves made
    max_moves: int  # Always 9 for TicTacToe
    terminated: bool  # Game finished

    def diff(self, prev_state: "TicTacToePublicState") -> Dict[str, Any]:
        """Return ``{field: new_value}`` for fields that differ from ``prev_state``.

        ``max_moves`` is not compared. The board is reported as a plain list.
        """
        differences: Dict[str, Any] = {}
        if not np.array_equal(self.board, prev_state.board):
            differences["board"] = self.board.tolist()
        if self.current_player != prev_state.current_player:
            differences["current_player"] = self.current_player
        if self.last_move != prev_state.last_move:
            differences["last_move"] = self.last_move
        if self.winner != prev_state.winner:
            differences["winner"] = self.winner
        if self.move_count != prev_state.move_count:
            differences["move_count"] = self.move_count
        if self.terminated != prev_state.terminated:
            differences["terminated"] = self.terminated
        return differences

    @property
    def board_text(self) -> str:
        """Render the board as text whose labels match the move coordinates.

        Columns are headed by the coordinate letter and rows by the coordinate
        number, so the cell shown under "B" on row "1" is exactly the cell the
        move "B1" fills. Each cell is looked up through ``COORD_TO_IDX``.
        Indexing as ``row * 3 + col`` would transpose the grid relative to
        the coordinates, because ``COORD_TO_IDX`` puts the letter in the
        "tens" place (B1 -> 3).
        """
        cells = np.asarray(self.board).reshape(-1)
        # Two leading spaces line the column letters up with the marks below.
        lines = ["  A B C"]
        for number in (1, 2, 3):
            marks = [
                MARK_TO_PLAYER[int(cells[COORD_TO_IDX[f"{letter}{number}"]])]
                for letter in "ABC"
            ]
            lines.append(f"{number} {' '.join(marks)}")
        return "\n".join(lines)
85
+
86
+
87
@dataclass
class TicTacToePrivateState:
    """Reward bookkeeping and termination flags kept alongside the public board."""

    reward_last: float  # reward produced by the most recent step
    total_reward: float  # running sum of step rewards
    terminated: bool
    truncated: bool

    def diff(self, prev_state: "TicTacToePrivateState") -> Dict[str, Any]:
        """Map each field whose value changed since ``prev_state`` to its new value."""
        changed: Dict[str, Any] = {}
        for field_name in ("reward_last", "total_reward", "terminated", "truncated"):
            current = getattr(self, field_name)
            if current != getattr(prev_state, field_name):
                changed[field_name] = current
        return changed
105
+
106
+
107
@dataclass
class TicTacToeEngineSnapshot(StatefulEngineSnapshot):
    """Serializable engine snapshot: the task instance plus raw game-state fields."""

    # Result of ``await task_instance.serialize()`` for the engine's task.
    task_instance_dict: Dict
    # Plain-Python game state: board list, players, counters and total reward.
    engine_snapshot: Dict
111
+
112
+
113
class TicTacToeWinComponent(RewardComponent):
    """Score +1 when the tracked player has won and -1 when the other player has."""

    def __init__(self, player_mark: str = "X"):
        super().__init__()
        self.player_mark = player_mark

    async def score(self, state: TicTacToePublicState, action: Any) -> float:
        winner = state.winner
        if winner == self.player_mark:
            return 1.0
        # Any other non-empty winner except "draw" means the opponent won.
        opponent_won = bool(winner) and winner != "draw"
        return -1.0 if opponent_won else 0.0
124
+
125
+
126
class TicTacToeDrawComponent(RewardComponent):
    """Reward for drawn games; currently neutral.

    Every state, draw or not, scores 0.0, so this component does not change
    the total. It presumably exists as a placeholder for a non-zero draw reward.
    """

    async def score(self, state: TicTacToePublicState, action: Any) -> float:
        # A draw (state.winner == "draw") is deliberately worth nothing.
        return 0.0
131
+
132
+
133
class TicTacToeIllegalMoveComponent(RewardComponent):
    """One-shot -1 penalty for an illegal move.

    The engine raises ``illegal_move_attempted`` before the reward stack is
    scored. ``score`` consumes the flag, so each illegal move is penalised
    exactly once.
    """

    def __init__(self):
        # Initialise the RewardComponent base (as TicTacToeWinComponent does) so
        # any state the base class sets up also exists on this component.
        super().__init__()
        self.illegal_move_attempted = False

    async def score(self, state: TicTacToePublicState, action: Any) -> float:
        if not self.illegal_move_attempted:
            return 0.0
        self.illegal_move_attempted = False
        return -1.0
142
+
143
+
144
class TicTacToeEngine(StatefulEngine, IReproducibleEngine):
    """Stateful TicTacToe engine with a reward stack and snapshot support.

    The board is a flat length-9 int array (0 = empty, 1 = X, 2 = O) addressed
    through ``COORD_TO_IDX``. X always moves first. Each legal move places the
    mark of ``current_player``, which alternates until the game ends.
    """

    def __init__(self, task_instance: TaskInstance):
        self.task_instance = task_instance
        self.illegal_move_component = TicTacToeIllegalMoveComponent()

        # The win component scores from the perspective of
        # ``metadata.starting_player`` when the task provides it, else "X".
        metadata = getattr(task_instance, "metadata", None)
        agent_player = getattr(metadata, "starting_player", "X")

        self.reward_stack = RewardStack(
            [
                TicTacToeWinComponent(player_mark=agent_player),
                TicTacToeDrawComponent(),
                self.illegal_move_component,
            ]
        )

        self._init_game_state()

    def _init_game_state(self) -> None:
        """Clear the board (X to move, zero reward) and replay any opening moves."""
        self.board = np.zeros(9, dtype=int)
        self.current_player = "X"
        self.last_move = None
        self.winner = None
        self.move_count = 0
        self.terminated = False
        self.total_reward = 0.0

        # Pre-placed moves from task metadata. Unknown or occupied cells are
        # skipped by _apply_move.
        metadata = getattr(self.task_instance, "metadata", None)
        opening_moves = getattr(metadata, "opening_moves", None)
        if opening_moves is not None:
            for move in opening_moves:
                self._apply_move(move)

    def _public_state(self) -> TicTacToePublicState:
        """Build the public view of the current game (the board is copied)."""
        return TicTacToePublicState(
            board=self.board.copy(),
            current_player=self.current_player,
            last_move=self.last_move,
            winner=self.winner,
            move_count=self.move_count,
            max_moves=9,
            terminated=self.terminated,
        )

    def _private_state(self, reward_last: float) -> TicTacToePrivateState:
        """Build the private view with the given last-step reward."""
        return TicTacToePrivateState(
            reward_last=reward_last,
            total_reward=self.total_reward,
            terminated=self.terminated,
            truncated=False,
        )

    async def _reset_engine(
        self, *, seed: int | None = None
    ) -> Tuple[TicTacToePrivateState, TicTacToePublicState]:
        """Restart the game. ``seed`` is accepted but unused; nothing here is random."""
        self._init_game_state()
        return self._private_state(reward_last=0.0), self._public_state()

    async def _step_engine(self, action: str) -> Tuple[TicTacToePrivateState, TicTacToePublicState]:
        """Apply ``action`` (e.g. "B2") for the current player and score the result.

        An illegal action (unknown coordinate or occupied cell) ends the game
        and triggers the illegal-move penalty. Once the game is over, further
        actions are ignored. They do not touch the board and earn 0 reward;
        otherwise the board would keep changing and the win reward would be
        granted again on every extra step.
        """
        if self.terminated:
            return self._private_state(reward_last=0.0), self._public_state()

        if not self._is_valid_move(action, self.board):
            self.illegal_move_component.illegal_move_attempted = True
            self.terminated = True
        else:
            self._apply_move(action)

        public_state = self._public_state()
        reward = await self.reward_stack.step_reward(public_state, action)
        self.total_reward += reward
        return self._private_state(reward_last=reward), public_state

    def _apply_move(self, coord: str):
        """Place the current player's mark at ``coord`` and update game status.

        Unknown coordinates, occupied cells and moves after the game has ended
        are ignored.
        """
        if self.terminated or coord not in COORD_TO_IDX:
            return

        idx = COORD_TO_IDX[coord]
        if self.board[idx] != 0:
            return

        self.board[idx] = PLAYER_MARKS[self.current_player]
        self.last_move = coord
        self.move_count += 1
        self.winner = self._check_winner(self.board)

        if self.winner is not None or self.move_count >= 9:
            self.terminated = True
        else:
            self.current_player = "O" if self.current_player == "X" else "X"

    def _check_winner(self, board: np.ndarray) -> Optional[str]:
        """Return "X"/"O" for a completed line, "draw" for a full board, else None."""
        for a, b, c in WIN_PATTERNS:
            if board[a] != 0 and board[a] == board[b] == board[c]:
                return MARK_TO_PLAYER[int(board[a])]

        if np.all(board != 0):
            return "draw"

        return None

    def _is_valid_move(self, coord: str, board: np.ndarray) -> bool:
        """True when ``coord`` is a known coordinate whose cell is empty."""
        if coord not in COORD_TO_IDX:
            return False
        return bool(board[COORD_TO_IDX[coord]] == 0)

    async def _serialize_engine(self) -> TicTacToeEngineSnapshot:
        """Capture the task instance and game state as plain Python values."""
        return TicTacToeEngineSnapshot(
            task_instance_dict=await self.task_instance.serialize(),
            engine_snapshot={
                "board": self.board.tolist(),
                "current_player": self.current_player,
                "last_move": self.last_move,
                "winner": self.winner,
                "move_count": self.move_count,
                "terminated": self.terminated,
                "total_reward": self.total_reward,
            },
        )

    @classmethod
    async def _deserialize_engine(cls, snapshot: TicTacToeEngineSnapshot) -> "TicTacToeEngine":
        """Rebuild an engine from a snapshot produced by ``_serialize_engine``."""
        task_instance = await TaskInstance.deserialize(snapshot.task_instance_dict)
        engine = cls(task_instance)

        # Overwrite the freshly initialised state (including any replayed
        # opening moves) with the saved one.
        saved = snapshot.engine_snapshot
        engine.board = np.array(saved["board"], dtype=int)
        engine.current_player = saved["current_player"]
        engine.last_move = saved["last_move"]
        engine.winner = saved["winner"]
        engine.move_count = saved["move_count"]
        engine.terminated = saved["terminated"]
        engine.total_reward = saved["total_reward"]

        return engine

    def get_current_states_for_observation(
        self,
    ) -> Tuple[TicTacToePrivateState, TicTacToePublicState]:
        """Return the current (private, public) states with ``reward_last`` = 0.0."""
        return self._private_state(reward_last=0.0), self._public_state()
338
+
339
+
340
class SynthTicTacToeObservationCallable(GetObservationCallable):
    """Default per-step observation: rendered board, turn info and rewards."""

    async def get_observation(
        self, pub: TicTacToePublicState, priv: TicTacToePrivateState
    ) -> InternalObservation:
        # Public game fields first, then the reward fields from the private state.
        return dict(
            board_text=pub.board_text,
            current_player=pub.current_player,
            move_count=pub.move_count,
            last_move=pub.last_move,
            winner=pub.winner,
            terminated=pub.terminated,
            reward_last=priv.reward_last,
            total_reward=priv.total_reward,
        )
355
+
356
+
357
class SynthTicTacToeCheckpointObservationCallable(GetObservationCallable):
    """Checkpoint/final observation: board, outcome, move count and total reward."""

    async def get_observation(
        self, pub: TicTacToePublicState, priv: TicTacToePrivateState
    ) -> InternalObservation:
        return dict(
            board_text_final=pub.board_text,
            winner_final=pub.winner,
            move_count_final=pub.move_count,
            total_reward=priv.total_reward,
            terminated=pub.terminated,
        )
@@ -0,0 +1,239 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, Optional, Any, List, Union
4
+ from pydantic import BaseModel
5
+
6
+ from synth_ai.environments.stateful.core import StatefulEnvironment
7
+ from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
8
+ from synth_ai.environments.environment.shared_engine import (
9
+ GetObservationCallable,
10
+ InternalObservation,
11
+ )
12
+ from synth_ai.environments.environment.tools import (
13
+ AbstractTool,
14
+ EnvToolCall,
15
+ ToolResult,
16
+ )
17
+ from synth_ai.environments.tasks.core import TaskInstance
18
+
19
+ from .engine import (
20
+ TicTacToeEngine,
21
+ TicTacToePublicState,
22
+ TicTacToePrivateState,
23
+ TicTacToeEngineSnapshot,
24
+ SynthTicTacToeObservationCallable,
25
+ SynthTicTacToeCheckpointObservationCallable,
26
+ )
27
+
28
+
29
class TicTacToeActionInput(BaseModel):
    # Call schema for the "interact" tool (TicTacToeInteractTool.call_schema).
    # The tool joins the two fields into a coordinate string, e.g. "A" + 1 -> "A1".
    # NOTE(review): kept comment-only on purpose; a class docstring would
    # presumably be emitted by pydantic as the schema "description".
    letter: str  # "A", "B", or "C"
    number: int  # 1, 2, or 3
32
+
33
+
34
class TicTacToeInteractTool(AbstractTool):
    """The "interact" tool: place the current player's mark on a letter/number cell."""

    name = "interact"
    description = "Place your mark (X or O) in the specified cell using letter (A, B, C) and number (1, 2, 3) coordinates."
    call_schema = TicTacToeActionInput
    result_schema = ToolResult

    def __init__(self, engine: TicTacToeEngine):
        self.engine = engine

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        """Validate ``call.args`` and apply the move to the engine.

        Args:
            call: tool call whose ``args`` carry ``letter`` ("A"-"C", any case,
                surrounding whitespace ignored) and ``number`` (1-3 as an int,
                whole float, or digit string).

        Returns:
            ``ToolResult(ok=True)`` with the engine's ``public_state`` and
            ``private_state`` in the payload. Returns ``ok=False`` with an
            error message when arguments are missing/invalid or the engine
            step raises.
        """
        try:
            letter = call.args.get("letter")
            number = call.args.get("number")

            if not letter or number is None:
                return ToolResult(
                    ok=False, error="Both letter and number parameters are required", payload={}
                )

            # Accept "b" / " B ", matching the "action" path of
            # TicTacToeEnvironment.validate_tool_calls, which upper-cases letters.
            normalized_letter = letter.strip().upper() if isinstance(letter, str) else letter
            if normalized_letter not in ["A", "B", "C"]:
                return ToolResult(
                    ok=False, error=f"Letter must be A, B, or C, got '{letter}'", payload={}
                )

            # Digit strings such as "2" are common in model-generated arguments.
            normalized_number = number
            if isinstance(number, str) and number.strip().isdecimal():
                normalized_number = int(number)
            # True (== 1) and 1.0 both pass the membership test, but would be
            # formatted as "ATrue" / "A1.0". The engine treats such strings as
            # illegal moves and ends the game. Reject bools and format via int().
            if isinstance(normalized_number, bool) or normalized_number not in [1, 2, 3]:
                return ToolResult(
                    ok=False, error=f"Number must be 1, 2, or 3, got {number}", payload={}
                )

            # Coordinate string, e.g. "A1", "B2".
            action = f"{normalized_letter}{int(normalized_number)}"
            private_state, public_state = await self.engine._step_engine(action)

            return ToolResult(
                ok=True,
                payload={"public_state": public_state, "private_state": private_state},
            )
        except Exception as e:
            # Tool boundary: report engine failures as an error result, not a raise.
            return ToolResult(ok=False, error=str(e), payload={})
78
+
79
+
80
class TicTacToeEnvironment(StatefulEnvironment, ReproducibleEnvironment[TicTacToeEngine]):
    """TicTacToe environment driven through the single "interact" tool.

    Wraps a ``TicTacToeEngine``. ``step`` accepts tool calls in several shapes
    (see ``validate_tool_calls``) and returns observations built by the
    configured observation callables.
    """

    def __init__(
        self,
        task_instance: TaskInstance,
        custom_step_obs: Optional[GetObservationCallable] = None,
        custom_ckpt_obs: Optional[GetObservationCallable] = None,
    ):
        self.name = "TicTacToe"
        self.task_instance = task_instance
        self.custom_step_observation_callable = (
            custom_step_obs or SynthTicTacToeObservationCallable()
        )
        self.custom_checkpoint_observation_callable = (
            custom_ckpt_obs or SynthTicTacToeCheckpointObservationCallable()
        )
        self.engine = TicTacToeEngine(task_instance)
        self._interact_tool = TicTacToeInteractTool(self.engine)

    async def initialize(self) -> InternalObservation:
        """Reset the engine and return the initial step observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._to_observation(priv, pub, self.custom_step_observation_callable)

    async def step(self, tool_calls) -> InternalObservation:
        """Apply one move and return the step observation.

        When the tool reports failure, the current state is returned with an
        extra ``"error"`` key.

        Raises:
            ValueError: if ``tool_calls`` cannot be normalized (see ``validate_tool_calls``).
        """
        validated_call = self.validate_tool_calls(tool_calls)
        result = await self._interact_tool(validated_call)

        if result.ok:
            priv = result.payload["private_state"]
            pub = result.payload["public_state"]
            return await self._to_observation(priv, pub, self.custom_step_observation_callable)

        priv, pub = self.engine.get_current_states_for_observation()
        return await self._to_observation(
            priv,
            pub,
            self.custom_step_observation_callable,
            extra_obs={"error": result.error},
        )

    async def checkpoint(self) -> InternalObservation:
        """Return the checkpoint observation for the current state."""
        priv, pub = self.engine.get_current_states_for_observation()
        return await self._to_observation(priv, pub, self.custom_checkpoint_observation_callable)

    async def terminate(self) -> InternalObservation:
        """Return a final observation with ``terminated`` forced to True.

        Only the returned state copies are marked; the engine itself is not changed.
        """
        priv, pub = self.engine.get_current_states_for_observation()
        pub.terminated = True
        priv.terminated = True
        return await self._to_observation(priv, pub, self.custom_checkpoint_observation_callable)

    @staticmethod
    def _coerce_int(value: Any) -> Optional[int]:
        """Return ``value`` as an int, or None when it is not an integral number.

        Accepts ints (including numpy integers), whole floats and digit strings.
        ``bool`` is rejected even though it subclasses int.
        """
        import operator

        if isinstance(value, bool):
            return None
        if isinstance(value, float):
            return int(value) if value.is_integer() else None
        if isinstance(value, str):
            stripped = value.strip()
            return int(stripped) if stripped.isdecimal() else None
        try:
            return operator.index(value)
        except TypeError:
            return None

    def validate_tool_calls(self, tool_calls) -> EnvToolCall:
        """Normalize ``tool_calls`` into one "interact" call with letter/number args.

        Accepted inputs:
        - an ``EnvToolCall``;
        - a dict in ``{"tool", "args"}`` form;
        - a legacy ``{"name", "parameters"}`` dict;
        - an OpenAI ``{"function": {"name", "arguments"}}`` dict, where
          ``arguments`` may be a JSON-encoded string;
        - a bare args dict;
        - a non-empty list, of which only the first call is used;
        - anything else, treated as an "A1"-style action string.

        Args ``position`` (0-8) and ``action`` ("A1") are converted to
        ``letter``/``number``. Letters are upper-cased and integral numbers
        (including digit strings) become ``int``.

        Raises:
            ValueError: empty list, unknown tool, unparsable or non-mapping
                arguments, or out-of-range coordinates.
        """
        import json  # only needed for OpenAI-style string arguments
        from collections.abc import Mapping

        if isinstance(tool_calls, EnvToolCall):
            validated_call = tool_calls
        elif isinstance(tool_calls, dict):
            if "tool" in tool_calls:
                validated_call = EnvToolCall(
                    tool=tool_calls["tool"], args=tool_calls.get("args", {})
                )
            elif "name" in tool_calls:
                # Legacy format
                validated_call = EnvToolCall(
                    tool=tool_calls["name"], args=tool_calls.get("parameters", {})
                )
            elif "function" in tool_calls:
                function = tool_calls["function"]
                raw_args = function.get("arguments", {})
                # OpenAI function calls deliver their arguments as a JSON string;
                # using it unparsed would make "letter" in args a substring test.
                if isinstance(raw_args, str):
                    try:
                        raw_args = json.loads(raw_args) if raw_args.strip() else {}
                    except json.JSONDecodeError as err:
                        raise ValueError(
                            f"Function arguments are not valid JSON: {raw_args!r}"
                        ) from err
                validated_call = EnvToolCall(tool=function["name"], args=raw_args)
            else:
                # Assume the dict is the argument mapping itself.
                validated_call = EnvToolCall(tool="interact", args=tool_calls)
        elif isinstance(tool_calls, list):
            if not tool_calls:
                raise ValueError("Empty tool calls list")
            validated_call = self.validate_tool_calls(tool_calls[0])
        else:
            validated_call = EnvToolCall(tool="interact", args={"action": str(tool_calls)})

        if validated_call.tool != "interact":
            raise ValueError(f"Unknown tool: {validated_call.tool}")

        args = validated_call.args if validated_call.args is not None else {}
        if not isinstance(args, Mapping):
            raise ValueError(f"Tool arguments must be a mapping, got {type(args).__name__}")

        # Convert legacy formats to the letter/number format.
        if "position" in args:
            # Numeric position 0-8 in COORD_TO_IDX order (0 -> A1, 3 -> B1, ...).
            position = self._coerce_int(args["position"])
            if position is None or not 0 <= position <= 8:
                raise ValueError(f"Position {args['position']} must be between 0 and 8")
            args = {"letter": ["A", "B", "C"][position // 3], "number": (position % 3) + 1}
        elif "action" in args:
            # Coordinate string such as "A1".
            action = str(args["action"]).strip()
            if len(action) != 2:
                raise ValueError(f"Action '{action}' must be 2 characters (e.g., 'A1')")
            letter = action[0].upper()
            try:
                number = int(action[1])
            except ValueError as err:
                raise ValueError(
                    f"Action '{action}' must have a numeric second character"
                ) from err
            args = {"letter": letter, "number": number}

        # Validate and normalize final letter/number values. Without int
        # normalization a float such as 2.0 would pass validation and reach the
        # engine as the illegal move "A2.0".
        if "letter" in args and "number" in args:
            raw_letter = args["letter"]
            raw_number = args["number"]
            letter = raw_letter.strip().upper() if isinstance(raw_letter, str) else raw_letter
            number = self._coerce_int(raw_number)
            if letter not in ["A", "B", "C"]:
                raise ValueError(f"Letter must be A, B, or C, got '{raw_letter}'")
            if number not in [1, 2, 3]:
                raise ValueError(f"Number must be 1, 2, or 3, got {raw_number}")
            args = {**args, "letter": letter, "number": number}

        return EnvToolCall(tool=validated_call.tool, args=args)

    async def _to_observation(
        self,
        priv: TicTacToePrivateState,
        pub: TicTacToePublicState,
        obs_cb: Optional[GetObservationCallable],
        extra_obs: Optional[Dict] = None,
    ) -> InternalObservation:
        """Build an observation with ``obs_cb`` (empty dict if None), merging ``extra_obs``."""
        obs: InternalObservation = await obs_cb.get_observation(pub, priv) if obs_cb else {}

        if extra_obs and isinstance(obs, dict):
            obs.update(extra_obs)

        return obs

    async def _serialize_engine(self) -> TicTacToeEngineSnapshot:
        """Delegate serialization to the engine."""
        return await self.engine._serialize_engine()

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: TicTacToeEngineSnapshot, task_instance: TaskInstance
    ) -> "TicTacToeEnvironment":
        """Create an environment and swap in an engine restored from ``snapshot``."""
        env = cls(task_instance)
        env.engine = await TicTacToeEngine._deserialize_engine(snapshot)
        # The tool holds its own engine reference; rebind it to the restored engine.
        env._interact_tool = TicTacToeInteractTool(env.engine)
        return env