synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123) hide show
  1. synth_ai/environments/examples/__init__.py +1 -0
  2. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  3. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  4. synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
  5. synth_ai/environments/examples/crafter_classic/engine.py +579 -0
  6. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  7. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  8. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  9. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
  10. synth_ai/environments/examples/crafter_classic/environment.py +364 -0
  11. synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
  12. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
  13. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
  14. synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
  15. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
  16. synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
  17. synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
  18. synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
  19. synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
  20. synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
  21. synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
  22. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
  23. synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
  24. synth_ai/environments/examples/crafter_custom/environment.py +312 -0
  25. synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
  26. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  27. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  28. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  29. synth_ai/environments/examples/enron/engine.py +291 -0
  30. synth_ai/environments/examples/enron/environment.py +165 -0
  31. synth_ai/environments/examples/enron/taskset.py +112 -0
  32. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  33. synth_ai/environments/examples/minigrid/engine.py +589 -0
  34. synth_ai/environments/examples/minigrid/environment.py +274 -0
  35. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  36. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  37. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  38. synth_ai/environments/examples/nethack/__init__.py +7 -0
  39. synth_ai/environments/examples/nethack/achievements.py +337 -0
  40. synth_ai/environments/examples/nethack/engine.py +738 -0
  41. synth_ai/environments/examples/nethack/environment.py +255 -0
  42. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  43. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  44. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  45. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  46. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  47. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  48. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  49. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  50. synth_ai/environments/examples/nethack/taskset.py +323 -0
  51. synth_ai/environments/examples/red/__init__.py +7 -0
  52. synth_ai/environments/examples/red/config_logging.py +110 -0
  53. synth_ai/environments/examples/red/engine.py +693 -0
  54. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  55. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  56. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  57. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  58. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  59. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  60. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  61. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  62. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  63. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  64. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  69. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  70. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  71. synth_ai/environments/examples/red/environment.py +235 -0
  72. synth_ai/environments/examples/red/taskset.py +77 -0
  73. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  74. synth_ai/environments/examples/sokoban/engine.py +675 -0
  75. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  76. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  77. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  78. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  79. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  80. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  81. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  82. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  87. synth_ai/environments/examples/sokoban/environment.py +228 -0
  88. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  89. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  90. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  91. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  92. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  93. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  94. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  95. synth_ai/environments/examples/verilog/__init__.py +10 -0
  96. synth_ai/environments/examples/verilog/engine.py +328 -0
  97. synth_ai/environments/examples/verilog/environment.py +349 -0
  98. synth_ai/environments/examples/verilog/taskset.py +418 -0
  99. synth_ai/environments/examples/wordle/__init__.py +29 -0
  100. synth_ai/environments/examples/wordle/engine.py +391 -0
  101. synth_ai/environments/examples/wordle/environment.py +154 -0
  102. synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
  103. synth_ai/environments/examples/wordle/taskset.py +222 -0
  104. synth_ai/environments/service/app.py +8 -0
  105. synth_ai/environments/service/core_routes.py +38 -0
  106. synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
  107. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
  108. synth_ai/learning/prompts/mipro.py +273 -1
  109. synth_ai/learning/prompts/random_search.py +247 -0
  110. synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
  111. synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
  112. synth_ai/lm/injection.py +81 -0
  113. synth_ai/lm/overrides.py +204 -0
  114. synth_ai/lm/provider_support/anthropic.py +39 -12
  115. synth_ai/lm/provider_support/openai.py +31 -4
  116. synth_ai/lm/vendors/core/anthropic_api.py +16 -0
  117. synth_ai/lm/vendors/openai_standard.py +35 -5
  118. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
  119. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
  120. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
  121. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
  122. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
  123. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,693 @@
1
+ from __future__ import annotations
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Dict, Any, Optional, List
5
+ from dataclasses import dataclass
6
+
7
+ # Import logging configuration first to suppress JAX debug messages
8
+
9
+ from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
10
+ from synth_ai.environments.reproducibility.core import IReproducibleEngine
11
+ from synth_ai.environments.environment.rewards.core import RewardStack
12
+ from synth_ai.environments.tasks.core import TaskInstance
13
+
14
+ from .engine_helpers.state_extraction import extract_game_state
15
+ from .engine_helpers.reward_components import (
16
+ BadgeRewardComponent,
17
+ MapTransitionComponent,
18
+ BattleVictoryComponent,
19
+ LevelUpComponent,
20
+ XPGainComponent,
21
+ StepPenaltyComponent,
22
+ )
23
+
24
# PyBoy is an optional dependency. When it cannot be imported the module still
# loads: PYBOY_AVAILABLE is False, PyBoy is None, and WindowEvent is a local
# stand-in class that only carries integer constants.
try:
    from pyboy import PyBoy
    from pyboy.pyboy import WindowEvent
except ImportError:
    PYBOY_AVAILABLE = False
    PyBoy = None

    class WindowEvent:
        """Fallback button-event constants used when PyBoy is not installed."""

        # Press events
        PRESS_BUTTON_A = 0
        PRESS_BUTTON_B = 1
        PRESS_ARROW_UP = 2
        PRESS_ARROW_DOWN = 3
        PRESS_ARROW_LEFT = 4
        PRESS_ARROW_RIGHT = 5
        PRESS_BUTTON_START = 6
        PRESS_BUTTON_SELECT = 7

        # Release events
        RELEASE_BUTTON_A = 8
        RELEASE_BUTTON_B = 9
        RELEASE_ARROW_UP = 10
        RELEASE_ARROW_DOWN = 11
        RELEASE_ARROW_LEFT = 12
        RELEASE_ARROW_RIGHT = 13
        RELEASE_BUTTON_START = 14
        RELEASE_BUTTON_SELECT = 15

else:
    PYBOY_AVAILABLE = True
53
+
54
+
55
# Game Boy button mappings: the upper-case action name used by this engine maps
# to the lower-case string name handed to PyBoy's button_press/button_release.
BUTTON_MAP = {
    action_name: action_name.lower()
    for action_name in ("A", "B", "UP", "DOWN", "LEFT", "RIGHT", "START", "SELECT")
}
66
+
67
+
68
@dataclass
class PokemonData:
    """Snapshot of a single party Pokemon as read from the game state."""

    # Numeric species identifier
    species_id: int
    # Current level
    level: int
    # Current and maximum hit points
    hp_current: int
    hp_max: int
    # Experience points
    xp: int
    # HP percentage as supplied by the state extractor
    hp_percentage: float

    # TODO: Add attack, defense, speed, special, status_conditions, moves and
    # nickname once the corresponding memory addresses are available.
86
+
87
+
88
@dataclass
class InventoryItem:
    """One inventory slot: which item it holds and how many."""

    # Numeric item identifier
    item_id: int
    # Number of units held
    quantity: int

    # TODO: Add name and category once an item name mapping exists.
97
+
98
+
99
@dataclass
class GameWorldState:
    """Where the player currently is: the active map and the player's position on it."""

    # Identifier of the current map
    map_id: int
    # Player coordinates on that map
    player_x: int
    player_y: int

    # TODO: Add when available: map_name, map_type (town, route, building,
    # dungeon), available_services (Pokemon Center, Pokemart, Gym, etc.),
    # npcs_nearby, items_on_ground and wild_encounters_available.
113
+
114
+
115
@dataclass
class GameSystemState:
    """Flags and raw values describing the game's current mode (battle, menu, text, warp)."""

    # True while a battle is in progress
    in_battle: bool
    # Raw battle outcome value from the extracted state
    battle_outcome: int
    # Raw menu state value from the extracted state
    menu_state: int
    # True while a text box is showing
    text_box_active: bool
    # Raw warp flag value from the extracted state
    warp_flag: int

    # TODO: Add current_menu_type, dialogue_speaker and available_actions
    # when available.
128
+
129
+
130
@dataclass
class PlayerProgressState:
    """How far the player has progressed: badges, money and steps taken."""

    # Raw badge value from the game state (the engine treats it as a bitfield)
    badges: int
    # Number of badges earned
    badge_count: int
    # Money held by the player
    money: int
    # Number of engine steps taken so far
    step_count: int

    # TODO: Add pokedex_seen, pokedex_caught, story_flags and time_played
    # when available.
143
+
144
+
145
@dataclass
class PokemonRedPublicState:
    """Structured, text-friendly view of the Pokemon Red game state.

    Groups the observable game information into world, progress, party,
    inventory and system sections so an agent can reason about the game
    without processing screen images. Based on requirements from
    text_port.txt.

    The read-only properties below expose the older flat field names for
    existing code; the ``party_*`` ones report on the first party member and
    return 0 when the party is empty.
    """

    # Map and player position
    world: GameWorldState
    # Badges, money and step counter
    progress: PlayerProgressState
    # Party members (up to 6 Pokemon)
    party: List[PokemonData]
    # Items carried
    inventory: List[InventoryItem]
    # Battle / menu / text-box / warp status
    system: GameSystemState
    # Description of a state-construction problem, if any
    error_info: Optional[str] = None

    def _lead_value(self, field_name: str) -> int:
        """Return ``field_name`` of the first party member, or 0 for an empty party."""
        if not self.party:
            return 0
        return getattr(self.party[0], field_name)

    # --- Legacy compatibility accessors (for existing code) ---

    @property
    def map_id(self) -> int:
        world = self.world
        return world.map_id

    @property
    def player_x(self) -> int:
        world = self.world
        return world.player_x

    @property
    def player_y(self) -> int:
        world = self.world
        return world.player_y

    @property
    def badges(self) -> int:
        progress = self.progress
        return progress.badges

    @property
    def in_battle(self) -> bool:
        system = self.system
        return system.in_battle

    @property
    def party_level(self) -> int:
        return self._lead_value("level")

    @property
    def party_hp_current(self) -> int:
        return self._lead_value("hp_current")

    @property
    def party_hp_max(self) -> int:
        return self._lead_value("hp_max")

    @property
    def party_xp(self) -> int:
        return self._lead_value("xp")

    @property
    def step_count(self) -> int:
        progress = self.progress
        return progress.step_count
212
+
213
+
214
@dataclass
class PokemonRedPrivateState:
    """Engine-side bookkeeping for an episode: rewards, end flags and step counter."""

    # Reward produced by the most recent step
    reward_last_step: float
    # Reward accumulated since the last reset
    total_reward: float
    # Episode-ended flags
    terminated: bool
    truncated: bool
    # Number of engine steps taken so far
    step_count: int
221
+
222
+
223
class PokemonRedEngineSnapshot(StatefulEngineSnapshot):
    """Checkpoint of a PokemonRedEngine: extracted state plus reward and step counters."""

    def __init__(self, state_data: Dict[str, Any], total_reward: float, step_count: int):
        # Extracted game-state dict as supplied by the engine
        self.state_data = state_data
        # Reward accumulated up to the checkpoint
        self.total_reward = total_reward
        # Steps taken up to the checkpoint
        self.step_count = step_count

    def model_dump(self) -> Dict[str, Any]:
        """Return the three snapshot fields as a plain dict keyed by attribute name."""
        dumped: Dict[str, Any] = {}
        for field_name in ("state_data", "total_reward", "step_count"):
            dumped[field_name] = getattr(self, field_name)
        return dumped
235
+
236
+
237
class PokemonRedEngine(StatefulEngine, IReproducibleEngine):
    """Pokemon Red game engine with dense reward tracking.

    Drives a PyBoy emulator with button presses, reads the game state back
    through ``extract_game_state`` and scores every step with a ``RewardStack``.
    When built with ``skip_rom_check=True`` no emulator is created
    (``self.emulator`` is ``None``) and the engine works on a fixed mock state,
    which is intended for tests.
    """

    def __init__(self, task_instance: TaskInstance, skip_rom_check: bool = False):
        """Create the engine and, unless skipped, boot the emulator.

        Args:
            task_instance: Task this engine runs; its optional
                ``initial_engine_snapshot`` is consulted on reset.
            skip_rom_check: When True, do not require PyBoy or a ROM and run
                without an emulator.

        Raises:
            ImportError: PyBoy is not installed and ``skip_rom_check`` is False.
            FileNotFoundError: No ROM file exists at the resolved path and
                ``skip_rom_check`` is False.
        """
        self.task_instance = task_instance

        # Initialize PyBoy emulator
        if not skip_rom_check:
            if not PYBOY_AVAILABLE:
                raise ImportError("PyBoy is required but not installed. Run: uv add pyboy")

            rom_path = self._get_rom_path()
            if not rom_path.exists():
                raise FileNotFoundError(
                    f"Pokemon Red ROM not found at {rom_path}. Please see README.md for setup instructions."
                )

            self.emulator = PyBoy(str(rom_path), window="null")

            # Load the working init state to get the game into a playable state
            self._load_init_state()
        else:
            # For testing purposes, use None emulator
            self.emulator = None

        # Initialize reward stack with dense components
        self.reward_stack = RewardStack(
            components=[
                BadgeRewardComponent(),
                MapTransitionComponent(),
                BattleVictoryComponent(),
                LevelUpComponent(),
                XPGainComponent(),
                StepPenaltyComponent(),
            ]
        )

        self._total_reward = 0.0
        self._step_count = 0
        self._previous_state: Optional[Dict[str, Any]] = None

    def _get_rom_path(self) -> Path:
        """Return the first existing ROM candidate, or the default location if none exists."""
        module_dir = Path(__file__).parent
        default_path = module_dir / "roms" / "pokemon_red.gb"

        # Check several possible locations
        possible_paths = [
            default_path,
            module_dir / "roms" / "PokemonRed.gb",
            module_dir / "vendor" / "pokemon_red.gb",
            Path.home() / "Games" / "pokemon_red.gb",
        ]

        for path in possible_paths:
            if path.exists():
                return path

        # Nothing found: hand back the expected location so the caller can
        # report it in its error message.
        return default_path

    def _load_init_state(self) -> None:
        """Load the initial save state to get the game into a playable state.

        Tries each known save-state file in order and stops at the first one
        that loads. If none loads, falls back to the PyBoy game wrapper's
        ``start_game`` when it exists. Failures are logged, never raised.
        """
        roms_dir = Path(__file__).parent / "roms"
        init_state_paths = [
            roms_dir / "working_init.state",
            roms_dir / "init.state",
        ]

        for state_path in init_state_paths:
            if state_path.exists():
                try:
                    with open(state_path, "rb") as f:
                        self.emulator.load_state(f)
                    logging.info(f"Loaded init state from: {state_path}")
                    return
                except Exception as e:
                    logging.warning(f"Failed to load init state from {state_path}: {e}")
                    continue

        # If no init state found, try to use PyBoy's game wrapper
        logging.warning("No init state found, trying PyBoy game wrapper...")
        try:
            if hasattr(self.emulator.game_wrapper, "start_game"):
                self.emulator.game_wrapper.start_game()
                logging.info("Used PyBoy game wrapper start_game()")
            else:
                logging.warning("PyBoy game wrapper doesn't have start_game method")
        except Exception as e:
            logging.warning(f"PyBoy game wrapper start_game failed: {e}")

    def _extract_current_state(self) -> Dict[str, Any]:
        """Extract current game state from emulator memory.

        Returns a fixed mock state when the engine has no emulator.
        """
        if self.emulator is None:
            # Return mock state for testing
            return {
                "map_id": 1,
                "player_x": 10,
                "player_y": 10,
                "badges": 0,
                "in_battle": False,
                "party_level": 5,
                "party_hp_current": 25,
                "party_hp_max": 25,
                "party_xp": 100,
            }

        # Get memory from PyBoy
        memory = self.emulator.memory
        return extract_game_state(memory)

    def _press_button(self, button: str, frames: int = 1):
        """Press a Game Boy button for specified frames.

        Args:
            button: One of the keys of ``BUTTON_MAP``.
            frames: Number of emulator ticks to hold the button.

        Raises:
            ValueError: ``button`` is not a key of ``BUTTON_MAP``.
        """
        if button not in BUTTON_MAP:
            raise ValueError(f"Invalid button: {button}. Valid buttons: {list(BUTTON_MAP.keys())}")

        button_name = BUTTON_MAP[button]

        if self.emulator is None:
            return  # Skip for testing

        # Press button
        self.emulator.button_press(button_name)

        # Hold for specified frames
        for _ in range(frames):
            self.emulator.tick()

        # Release button
        self.emulator.button_release(button_name)

        # Let release take effect
        self.emulator.tick()

    def _press_button_with_retry(
        self, button: str, frames: int = 1, max_attempts: int = 10
    ) -> bool:
        """Press a button, retrying movement commands until the player moves.

        For movement buttons (UP, DOWN, LEFT, RIGHT) the press is repeated
        until the player position or map changes, or ``max_attempts`` presses
        have been made. All other buttons (A, B, START, SELECT) are pressed
        exactly once, like ``_press_button``.

        Note: Previous menu-closing logic for 'B' button was removed because
        investigation showed that menu_state memory address represents
        "selected menu item index" not "menu is open", leading to false positives.

        Returns:
            False only when a movement button was pressed ``max_attempts``
            times without a detected position or map change; True otherwise.
        """
        movement_buttons = {"UP", "DOWN", "LEFT", "RIGHT"}

        if button not in movement_buttons:
            # For all other buttons (A, B, START, SELECT), just press once
            # No retry logic needed - let the game handle the response naturally
            self._press_button(button, frames)
            return True

        if self.emulator is None:
            return True  # Skip for testing

        # Get initial position
        try:
            initial_state = self._extract_current_state()
            initial_position = (
                initial_state.get("player_x", 0),
                initial_state.get("player_y", 0),
            )
            initial_map = initial_state.get("map_id", 0)
        except Exception as e:
            logging.warning(f"Could not extract initial state for movement retry: {e}")
            # Fall back to single press
            self._press_button(button, frames)
            return True

        for attempt in range(max_attempts):
            # Press the button
            self._press_button(button, frames)

            # Check if position changed
            try:
                new_state = self._extract_current_state()
                new_position = (
                    new_state.get("player_x", 0),
                    new_state.get("player_y", 0),
                )
                new_map = new_state.get("map_id", 0)

                # Movement successful if position or map changed
                if new_position != initial_position or new_map != initial_map:
                    logging.debug(
                        f"Movement successful after {attempt + 1} attempts: {initial_position} -> {new_position}"
                    )
                    return True

            except Exception as e:
                logging.warning(
                    f"Could not extract state during movement retry attempt {attempt + 1}: {e}"
                )
                continue

        # If we get here, movement didn't occur after max_attempts
        logging.warning(
            f"Movement button {button} pressed {max_attempts} times but no position change detected"
        )
        return False

    def _create_states(
        self, reward: float, terminated: bool = False
    ) -> tuple[PokemonRedPrivateState, PokemonRedPublicState]:
        """Create private and public state objects from the current game state.

        Args:
            reward: Reward of the step that just ran.
            terminated: Whether the episode ended on this step.

        Returns:
            ``(private_state, public_state)``. If the game state cannot be
            read, zeroed defaults are used. If building the states fails with
            a ``TypeError``/``ValueError``, a terminated private state and an
            empty public state carrying ``error_info`` are returned instead.
        """
        try:
            current_state = self._extract_current_state()
        except Exception as e:
            logging.error(f"Error extracting game state: {e}")
            # Provide default state values
            current_state = {
                "map_id": 0,
                "player_x": 0,
                "player_y": 0,
                "badges": 0,
                "in_battle": False,
                "party_pokemon": [],
                "inventory_items": [],
                "money": 0,
                "battle_outcome": 0,
                "menu_state": 0,
                "text_box_active": False,
                "warp_flag": 0,
            }

        try:
            private_state = PokemonRedPrivateState(
                reward_last_step=reward,
                total_reward=self._total_reward,
                terminated=terminated,
                truncated=False,
                step_count=self._step_count,
            )

            # Extract comprehensive game state data
            map_id = int(current_state.get("map_id", 0))
            player_x = int(current_state.get("player_x", 0))
            player_y = int(current_state.get("player_y", 0))
            badges = int(current_state.get("badges", 0))
            money = int(current_state.get("money", 0))

            # badges is a bitfield: the badge count is the number of set bits
            badge_count = bin(badges).count("1")

            # Create Pokemon party from detailed party data
            party_pokemon_data = current_state.get("party_pokemon", [])
            party = []
            for pokemon_data in party_pokemon_data:
                try:
                    pokemon = PokemonData(
                        species_id=int(pokemon_data.get("species_id", 0)),
                        level=int(pokemon_data.get("level", 1)),
                        hp_current=int(pokemon_data.get("hp_current", 1)),
                        hp_max=int(pokemon_data.get("hp_max", 1)),
                        xp=int(pokemon_data.get("xp", 0)),
                        hp_percentage=float(pokemon_data.get("hp_percentage", 100.0)),
                    )
                    party.append(pokemon)
                except (TypeError, ValueError) as e:
                    # Skip the malformed entry but keep the rest of the party
                    logging.warning(f"Error creating Pokemon data: {e}")
                    continue

            # Create inventory from detailed inventory data
            inventory_data = current_state.get("inventory_items", [])
            inventory = []
            for item_data in inventory_data:
                try:
                    item = InventoryItem(
                        item_id=int(item_data.get("item_id", 0)),
                        quantity=int(item_data.get("quantity", 0)),
                    )
                    inventory.append(item)
                except (TypeError, ValueError) as e:
                    # Skip the malformed entry but keep the rest of the inventory
                    logging.warning(f"Error creating inventory item: {e}")
                    continue

            # Create comprehensive public state
            public_state = PokemonRedPublicState(
                world=GameWorldState(map_id=map_id, player_x=player_x, player_y=player_y),
                progress=PlayerProgressState(
                    badges=badges,
                    badge_count=badge_count,
                    money=money,
                    step_count=self._step_count,
                ),
                party=party,
                inventory=inventory,
                system=GameSystemState(
                    in_battle=bool(current_state.get("in_battle", False)),
                    battle_outcome=int(current_state.get("battle_outcome", 0)),
                    menu_state=int(current_state.get("menu_state", 0)),
                    text_box_active=bool(current_state.get("text_box_active", False)),
                    warp_flag=int(current_state.get("warp_flag", 0)),
                ),
            )

        except (TypeError, ValueError) as e:
            logging.error(f"Error creating states with data {current_state}: {e}")
            # Create minimal safe states
            private_state = PokemonRedPrivateState(
                reward_last_step=0.0,
                total_reward=0.0,
                terminated=True,
                truncated=False,
                step_count=self._step_count,
            )
            public_state = PokemonRedPublicState(
                world=GameWorldState(map_id=0, player_x=0, player_y=0),
                progress=PlayerProgressState(
                    badges=0, badge_count=0, money=0, step_count=self._step_count
                ),
                party=[],
                inventory=[],
                system=GameSystemState(
                    in_battle=False,
                    battle_outcome=0,
                    menu_state=0,
                    text_box_active=False,
                    warp_flag=0,
                ),
                error_info=f"State creation error: {e}",
            )

        return private_state, public_state

    async def _reset_engine(
        self, *, seed: Optional[int] = None
    ) -> tuple[PokemonRedPrivateState, PokemonRedPublicState]:
        """Reset the Pokemon Red engine to initial state.

        If the task instance carries an ``initial_engine_snapshot`` that is an
        existing ``Path``, that save state is loaded into the emulator. Reward
        and step counters are zeroed and the current game state becomes the
        baseline for the next step's reward.

        Args:
            seed: Accepted for interface compatibility; not used here.
        """
        snapshot_path = getattr(self.task_instance, "initial_engine_snapshot", None)
        if (
            self.emulator is not None  # nothing to load into when running without an emulator
            and isinstance(snapshot_path, Path)
            and snapshot_path.exists()
        ):
            # load_state takes a binary file object (as in _load_init_state),
            # not a path string.
            with snapshot_path.open("rb") as state_file:
                self.emulator.load_state(state_file)

        self._total_reward = 0.0
        self._step_count = 0
        self._previous_state = self._extract_current_state()

        return self._create_states(reward=0.0)

    async def _step_engine(
        self, action: Dict[str, Any]
    ) -> tuple[PokemonRedPrivateState, PokemonRedPublicState]:
        """Execute one step in the Pokemon Red environment.

        Args:
            action: Mapping with optional ``"button"`` (default ``"A"``) and
                ``"frames"`` (default ``1``) entries.

        Returns:
            ``(private_state, public_state)`` after the button press. Any
            unexpected error yields a terminated state with reward ``-1.0``.
            The step counter advances exactly once per call either way.
        """
        step_counted = False
        try:
            # Extract previous state for reward calculation
            prev_state = self._previous_state or self._extract_current_state()

            # Execute action (button press)
            button = action.get("button", "A")
            frames = action.get("frames", 1)

            self._press_button_with_retry(button, frames)

            self._step_count += 1
            step_counted = True

            # Extract new state
            current_state = self._extract_current_state()

            # Calculate reward using reward stack
            try:
                reward = await self.reward_stack.step_reward(
                    state=current_state,
                    action={
                        "prev_badges": int(prev_state.get("badges", 0)),
                        "prev_map_id": int(prev_state.get("map_id", 0)),
                        "prev_in_battle": bool(prev_state.get("in_battle", False)),
                        "prev_party_level": int(prev_state.get("party_level", 0)),
                        "prev_party_xp": int(prev_state.get("party_xp", 0)),
                    },
                )
            except Exception as e:
                logging.error(f"Error calculating reward: {e}")
                reward = -0.01  # Small penalty for error

            self._total_reward += reward
            self._previous_state = current_state

            # Check termination condition (example: got Boulder Badge)
            try:
                badges = current_state.get("badges", 0)
                badges = int(badges) if badges is not None else 0
                terminated = (badges & 0x01) != 0
            except (TypeError, ValueError) as e:
                logging.error(
                    f"Error checking termination condition with badges={current_state.get('badges')}: {e}"
                )
                terminated = False

            return self._create_states(reward=reward, terminated=terminated)

        except Exception as e:
            logging.error(f"Error in step engine: {e}")
            # Still count the step on error, but only if the failure happened
            # before the normal increment; otherwise it would be counted twice.
            if not step_counted:
                self._step_count += 1
            # Return safe default states
            return self._create_states(reward=-1.0, terminated=True)

    async def _serialize_engine(self) -> PokemonRedEngineSnapshot:
        """Serialize engine state for checkpointing.

        The emulator save state is captured in memory and stored as raw bytes
        under the ``"_save_state_bytes"`` key of the snapshot's ``state_data``,
        next to the extracted game state.
        """
        import io

        if self.emulator is not None:
            # Save into an in-memory buffer: no temporary file is left on disk.
            buffer = io.BytesIO()
            self.emulator.save_state(buffer)
            state_bytes = buffer.getvalue()
        else:
            # For testing without emulator
            state_bytes = b"mock_state_data"

        current_state = self._extract_current_state()
        current_state["_save_state_bytes"] = state_bytes

        return PokemonRedEngineSnapshot(
            state_data=current_state,
            total_reward=self._total_reward,
            step_count=self._step_count,
        )

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: PokemonRedEngineSnapshot, task_instance: TaskInstance
    ) -> "PokemonRedEngine":
        """Deserialize engine from checkpoint.

        Builds a fresh engine for ``task_instance``, loads the emulator save
        state stored in the snapshot (when both are present) and restores the
        reward and step counters.
        """
        engine = cls(task_instance)

        # Restore save state if available
        if "_save_state_bytes" in snapshot.state_data and engine.emulator is not None:
            import io

            state_bytes = snapshot.state_data["_save_state_bytes"]
            state_io = io.BytesIO(state_bytes)
            engine.emulator.load_state(state_io)

        engine._total_reward = snapshot.total_reward
        engine._step_count = snapshot.step_count
        # The raw save-state bytes are not part of the game state used for rewards
        engine._previous_state = {
            k: v for k, v in snapshot.state_data.items() if k != "_save_state_bytes"
        }

        return engine
@@ -0,0 +1 @@
1
+ # Engine helpers for Pokemon Red