synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.1.9.dist-info/METADATA +0 -37
- synth_ai-0.1.9.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,693 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
import logging
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Dict, Any, Optional, List
|
5
|
+
from dataclasses import dataclass
|
6
|
+
|
7
|
+
# Import logging configuration first to suppress JAX debug messages
|
8
|
+
|
9
|
+
from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
|
10
|
+
from synth_ai.environments.reproducibility.core import IReproducibleEngine
|
11
|
+
from synth_ai.environments.environment.rewards.core import RewardStack
|
12
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
13
|
+
|
14
|
+
from .engine_helpers.state_extraction import extract_game_state
|
15
|
+
from .engine_helpers.reward_components import (
|
16
|
+
BadgeRewardComponent,
|
17
|
+
MapTransitionComponent,
|
18
|
+
BattleVictoryComponent,
|
19
|
+
LevelUpComponent,
|
20
|
+
XPGainComponent,
|
21
|
+
StepPenaltyComponent,
|
22
|
+
)
|
23
|
+
|
24
|
+
# PyBoy is an optional dependency. Record whether it could be imported and,
# when it could not, bind stand-ins so that this module still imports.
try:
    from pyboy import PyBoy
    from pyboy.pyboy import WindowEvent
except ImportError:
    PYBOY_AVAILABLE = False
    PyBoy = None

    class WindowEvent:
        # Placeholder carrying the press/release constant names as plain
        # integers 0..15 (presses first, then the matching releases).
        (
            PRESS_BUTTON_A,
            PRESS_BUTTON_B,
            PRESS_ARROW_UP,
            PRESS_ARROW_DOWN,
            PRESS_ARROW_LEFT,
            PRESS_ARROW_RIGHT,
            PRESS_BUTTON_START,
            PRESS_BUTTON_SELECT,
            RELEASE_BUTTON_A,
            RELEASE_BUTTON_B,
            RELEASE_ARROW_UP,
            RELEASE_ARROW_DOWN,
            RELEASE_ARROW_LEFT,
            RELEASE_ARROW_RIGHT,
            RELEASE_BUTTON_START,
            RELEASE_BUTTON_SELECT,
        ) = range(16)
else:
    PYBOY_AVAILABLE = True
|
53
|
+
|
54
|
+
|
55
|
+
# Upper-case action names accepted by the engine, mapped to the lower-case
# button identifiers that are passed to the emulator's button_press() and
# button_release() calls.
BUTTON_MAP = {
    action: action.lower()
    for action in ("A", "B", "UP", "DOWN", "LEFT", "RIGHT", "START", "SELECT")
}
|
66
|
+
|
67
|
+
|
68
|
+
@dataclass
class PokemonData:
    """Stats recorded for a single Pokemon.

    The field order defines the positional constructor signature, so it must
    not be rearranged.
    """

    species_id: int  # numeric species identifier
    level: int  # current level
    hp_current: int  # hit points right now
    hp_max: int  # hit-point ceiling
    xp: int  # experience points
    # NOTE(review): the scale (0-1 vs 0-100) is decided by whoever builds the
    # instance -- confirm against the producer.
    hp_percentage: float

    # TODO: extend once the corresponding memory addresses are available:
    #   attack, defense, speed, special (int, default 0)
    #   status_conditions, moves (List[str], default None)
    #   nickname (str, default "")
|
86
|
+
|
87
|
+
|
88
|
+
@dataclass
class InventoryItem:
    """One inventory entry: an item identifier together with a count."""

    item_id: int  # numeric item identifier
    quantity: int  # how many of the item are held

    # TODO: once an item-name mapping exists, add:
    #   name (str, default "") and category (str, default "")
|
97
|
+
|
98
|
+
|
99
|
+
@dataclass
class GameWorldState:
    """Where the player currently is: a map identifier plus x/y coordinates."""

    map_id: int  # identifier of the current map
    player_x: int  # player's x coordinate
    player_y: int  # player's y coordinate

    # TODO: add when the data becomes available:
    #   map_name (str, default "")
    #   map_type (str, default "") -- town, route, building, dungeon
    #   available_services (List[str]) -- Pokemon Center, Pokemart, Gym, etc.
    #   npcs_nearby (List[str])
    #   items_on_ground (List[str])
    #   wild_encounters_available (bool, default False)
|
113
|
+
|
114
|
+
|
115
|
+
@dataclass
class GameSystemState:
    """Flags and codes describing the game's system state (menus, battles, etc.)."""

    in_battle: bool  # whether a battle is in progress
    battle_outcome: int  # raw battle outcome code
    menu_state: int  # raw menu state value
    text_box_active: bool  # whether a text box is showing
    warp_flag: int  # raw warp flag value

    # TODO: add when the data becomes available:
    #   current_menu_type (str, default "")
    #   dialogue_speaker (str, default "")
    #   available_actions (List[str])
|
128
|
+
|
129
|
+
|
130
|
+
@dataclass
class PlayerProgressState:
    """How far the player has progressed: badges, money and steps taken."""

    # NOTE(review): `badges` looks like the raw badge value and `badge_count`
    # the number earned -- confirm against the code that fills this in.
    badges: int
    badge_count: int
    money: int  # money held by the player
    step_count: int  # number of steps taken so far

    # TODO: add when the data becomes available:
    #   pokedex_seen, pokedex_caught (int, default 0)
    #   story_flags (List[str])
    #   time_played (str, default "00:00")
|
143
|
+
|
144
|
+
|
145
|
+
@dataclass
class PokemonRedPublicState:
    """Structured, text-friendly view of the Pokemon Red game state.

    Groups the game information into world, progress, party, inventory and
    system sections so that a consumer can reason about the game without
    looking at the screen. Modelled on the requirements in text_port.txt.

    The read-only properties below re-expose selected nested values under
    their older flat names for code that has not moved to the grouped layout.
    """

    world: GameWorldState  # map and player position
    progress: PlayerProgressState  # badges, money, steps
    party: List[PokemonData]  # party members (up to 6 Pokemon)
    inventory: List[InventoryItem]  # items held
    system: GameSystemState  # battle / menu / text-box flags
    error_info: Optional[str] = None  # error description, if any

    # --- flat accessors forwarding to the `world` section ---

    @property
    def map_id(self) -> int:
        """Map identifier taken from ``world``."""
        return self.world.map_id

    @property
    def player_x(self) -> int:
        """Player x coordinate taken from ``world``."""
        return self.world.player_x

    @property
    def player_y(self) -> int:
        """Player y coordinate taken from ``world``."""
        return self.world.player_y

    # --- flat accessors forwarding to `progress` / `system` ---

    @property
    def badges(self) -> int:
        """Badge value taken from ``progress``."""
        return self.progress.badges

    @property
    def in_battle(self) -> bool:
        """Battle flag taken from ``system``."""
        return self.system.in_battle

    # --- flat accessors for the first party member; 0 when the party is empty ---

    @property
    def party_level(self) -> int:
        """Level of the first party member, or 0 without a party."""
        if not self.party:
            return 0
        return self.party[0].level

    @property
    def party_hp_current(self) -> int:
        """Current HP of the first party member, or 0 without a party."""
        if not self.party:
            return 0
        return self.party[0].hp_current

    @property
    def party_hp_max(self) -> int:
        """Maximum HP of the first party member, or 0 without a party."""
        if not self.party:
            return 0
        return self.party[0].hp_max

    @property
    def party_xp(self) -> int:
        """XP of the first party member, or 0 without a party."""
        if not self.party:
            return 0
        return self.party[0].xp

    @property
    def step_count(self) -> int:
        """Step counter taken from ``progress``."""
        return self.progress.step_count
|
212
|
+
|
213
|
+
|
214
|
+
@dataclass
class PokemonRedPrivateState:
    """Reward and episode bookkeeping kept alongside the public game state."""

    reward_last_step: float  # reward for the most recent step
    total_reward: float  # running reward total
    terminated: bool  # termination flag
    truncated: bool  # truncation flag
    step_count: int  # number of steps taken
|
221
|
+
|
222
|
+
|
223
|
+
class PokemonRedEngineSnapshot(StatefulEngineSnapshot):
    """Plain container for an engine snapshot: state dict, reward total, step count."""

    def __init__(self, state_data: Dict[str, Any], total_reward: float, step_count: int):
        """Store the three snapshot values on the instance unchanged."""
        self.state_data, self.total_reward, self.step_count = (
            state_data,
            total_reward,
            step_count,
        )

    def model_dump(self) -> Dict[str, Any]:
        """Return the snapshot as a dict keyed by attribute name.

        ``state_data`` is passed through by reference, not copied.
        """
        dumped: Dict[str, Any] = {}
        dumped["state_data"] = self.state_data
        dumped["total_reward"] = self.total_reward
        dumped["step_count"] = self.step_count
        return dumped
|
235
|
+
|
236
|
+
|
237
|
+
class PokemonRedEngine(StatefulEngine, IReproducibleEngine):
|
238
|
+
"""Pokemon Red game engine with dense reward tracking"""
|
239
|
+
|
240
|
+
def __init__(self, task_instance: TaskInstance, skip_rom_check: bool = False):
    """Set up the emulator (unless skipped), the reward stack and the counters.

    Args:
        task_instance: Task this engine runs; stored as ``self.task_instance``.
        skip_rom_check: When True no emulator is created and
            ``self.emulator`` is ``None`` (intended for tests).

    Raises:
        ImportError: PyBoy is not installed and ``skip_rom_check`` is False.
        FileNotFoundError: No ROM file exists at the resolved path and
            ``skip_rom_check`` is False.
    """
    self.task_instance = task_instance

    if skip_rom_check:
        # Test mode: run without an emulator.
        self.emulator = None
    else:
        if not PYBOY_AVAILABLE:
            raise ImportError("PyBoy is required but not installed. Run: uv add pyboy")

        rom_file = self._get_rom_path()
        if not rom_file.exists():
            raise FileNotFoundError(
                f"Pokemon Red ROM not found at {rom_file}. Please see README.md for setup instructions."
            )

        # window="null" is passed so that no display window is requested.
        self.emulator = PyBoy(str(rom_file), window="null")

        # Bring the game into a playable state via a saved init state.
        self._load_init_state()

    # Dense reward components, evaluated through a single RewardStack.
    dense_components = [
        BadgeRewardComponent(),
        MapTransitionComponent(),
        BattleVictoryComponent(),
        LevelUpComponent(),
        XPGainComponent(),
        StepPenaltyComponent(),
    ]
    self.reward_stack = RewardStack(components=dense_components)

    # Episode bookkeeping starts from zero with no remembered state.
    self._total_reward = 0.0
    self._step_count = 0
    self._previous_state: Optional[Dict[str, Any]] = None
|
277
|
+
|
278
|
+
def _get_rom_path(self) -> Path:
|
279
|
+
"""Get path to Pokemon Red ROM file"""
|
280
|
+
# Check several possible locations
|
281
|
+
possible_paths = [
|
282
|
+
Path(__file__).parent / "roms" / "pokemon_red.gb",
|
283
|
+
Path(__file__).parent / "roms" / "PokemonRed.gb",
|
284
|
+
Path(__file__).parent / "vendor" / "pokemon_red.gb",
|
285
|
+
Path.home() / "Games" / "pokemon_red.gb",
|
286
|
+
]
|
287
|
+
|
288
|
+
for path in possible_paths:
|
289
|
+
if path.exists():
|
290
|
+
return path
|
291
|
+
|
292
|
+
# Return default expected location
|
293
|
+
return Path(__file__).parent / "roms" / "pokemon_red.gb"
|
294
|
+
|
295
|
+
def _load_init_state(self) -> None:
|
296
|
+
"""Load the initial save state to get the game into a playable state"""
|
297
|
+
init_state_paths = [
|
298
|
+
Path(__file__).parent / "roms" / "working_init.state",
|
299
|
+
Path(__file__).parent / "roms" / "init.state",
|
300
|
+
]
|
301
|
+
|
302
|
+
for state_path in init_state_paths:
|
303
|
+
if state_path.exists():
|
304
|
+
try:
|
305
|
+
with open(state_path, "rb") as f:
|
306
|
+
self.emulator.load_state(f)
|
307
|
+
logging.info(f"Loaded init state from: {state_path}")
|
308
|
+
return
|
309
|
+
except Exception as e:
|
310
|
+
logging.warning(f"Failed to load init state from {state_path}: {e}")
|
311
|
+
continue
|
312
|
+
|
313
|
+
# If no init state found, try to use PyBoy's game wrapper
|
314
|
+
logging.warning("No init state found, trying PyBoy game wrapper...")
|
315
|
+
try:
|
316
|
+
if hasattr(self.emulator.game_wrapper, "start_game"):
|
317
|
+
self.emulator.game_wrapper.start_game()
|
318
|
+
logging.info("Used PyBoy game wrapper start_game()")
|
319
|
+
else:
|
320
|
+
logging.warning("PyBoy game wrapper doesn't have start_game method")
|
321
|
+
except Exception as e:
|
322
|
+
logging.warning(f"PyBoy game wrapper start_game failed: {e}")
|
323
|
+
|
324
|
+
def _extract_current_state(self) -> Dict[str, Any]:
|
325
|
+
"""Extract current game state from emulator memory"""
|
326
|
+
if self.emulator is None:
|
327
|
+
# Return mock state for testing
|
328
|
+
return {
|
329
|
+
"map_id": 1,
|
330
|
+
"player_x": 10,
|
331
|
+
"player_y": 10,
|
332
|
+
"badges": 0,
|
333
|
+
"in_battle": False,
|
334
|
+
"party_level": 5,
|
335
|
+
"party_hp_current": 25,
|
336
|
+
"party_hp_max": 25,
|
337
|
+
"party_xp": 100,
|
338
|
+
}
|
339
|
+
|
340
|
+
# Get memory from PyBoy
|
341
|
+
memory = self.emulator.memory
|
342
|
+
return extract_game_state(memory)
|
343
|
+
|
344
|
+
def _press_button(self, button: str, frames: int = 1):
    """Press ``button``, hold it for ``frames`` ticks, then release it.

    Args:
        button: One of the keys of ``BUTTON_MAP`` (e.g. ``"A"``, ``"UP"``).
        frames: Number of emulator ticks to keep the button held.

    Raises:
        ValueError: ``button`` is not a key of ``BUTTON_MAP``. This is
            checked before the emulator is consulted, so it also applies
            when no emulator is present.
    """
    if button not in BUTTON_MAP:
        raise ValueError(f"Invalid button: {button}. Valid buttons: {list(BUTTON_MAP.keys())}")

    emulator_key = BUTTON_MAP[button]
    emulator = self.emulator

    # Without an emulator (test mode) there is nothing to press.
    if emulator is None:
        return

    emulator.button_press(emulator_key)

    # Keep the button held for the requested number of frames.
    held = 0
    while held < frames:
        emulator.tick()
        held += 1

    emulator.button_release(emulator_key)

    # One extra tick so that the release is registered.
    emulator.tick()
|
366
|
+
|
367
|
+
def _press_button_with_retry(
    self, button: str, frames: int = 1, max_attempts: int = 10
) -> bool:
    """
    Press a button, repeating directional presses until the player moves.

    UP, DOWN, LEFT and RIGHT are pressed up to ``max_attempts`` times; the
    loop stops as soon as the player's position or the map id differs from
    the values read before the first press. Every other button (A, B, START,
    SELECT) is pressed exactly once, like ``_press_button``.

    Note: there is deliberately no menu-closing retry for 'B'. The
    menu_state memory address was found to hold the "selected menu item
    index" rather than "menu is open", which produced false positives.

    Returns True when the expected change was observed, and always True for
    buttons that are not retried. Returns False when a directional button
    was pressed ``max_attempts`` times without any detected change.
    """
    directional = {"UP", "DOWN", "LEFT", "RIGHT"}

    if button not in directional:
        # Non-movement buttons: one press, the game handles the rest.
        self._press_button(button, frames)
        return True

    if self.emulator is None:
        return True  # Skip for testing

    def read_location():
        # Position tuple and map id taken from a single state extraction.
        state = self._extract_current_state()
        coords = (state.get("player_x", 0), state.get("player_y", 0))
        return coords, state.get("map_id", 0)

    try:
        start_position, start_map = read_location()
    except Exception as e:
        logging.warning(f"Could not extract initial state for movement retry: {e}")
        # Without a baseline there is nothing to compare: press once.
        self._press_button(button, frames)
        return True

    for attempt in range(max_attempts):
        self._press_button(button, frames)

        try:
            position, map_id = read_location()
            # Either a coordinate change or a map change counts as movement.
            if position != start_position or map_id != start_map:
                logging.debug(
                    f"Movement successful after {attempt + 1} attempts: {start_position} -> {position}"
                )
                return True
        except Exception as e:
            logging.warning(
                f"Could not extract state during movement retry attempt {attempt + 1}: {e}"
            )

    # Every attempt was used up without an observable change.
    logging.warning(
        f"Movement button {button} pressed {max_attempts} times but no position change detected"
    )
    return False
|
442
|
+
|
443
|
+
def _create_states(
    self, reward: float, terminated: bool = False
) -> tuple[PokemonRedPrivateState, PokemonRedPublicState]:
    """Build the private and public state objects for the current game state.

    Game data comes from ``_extract_current_state``; if that call raises,
    neutral default values are used instead. Party and inventory entries
    that fail with ``TypeError``/``ValueError`` are logged and skipped. If
    assembling the states themselves fails with ``TypeError``/``ValueError``,
    a minimal, terminated state pair carrying ``error_info`` is returned.

    Args:
        reward: Value recorded as ``reward_last_step``.
        terminated: Whether the private state is marked as terminated.

    Returns:
        A ``(private_state, public_state)`` tuple.
    """
    try:
        snapshot = self._extract_current_state()
    except Exception as e:
        logging.error(f"Error extracting game state: {e}")
        # Extraction failed: carry on with neutral placeholder values.
        snapshot = {
            "map_id": 0,
            "player_x": 0,
            "player_y": 0,
            "badges": 0,
            "in_battle": False,
            "party_pokemon": [],
            "inventory_items": [],
            "money": 0,
            "battle_outcome": 0,
            "menu_state": 0,
            "text_box_active": False,
            "warp_flag": 0,
        }

    try:
        private_state = PokemonRedPrivateState(
            reward_last_step=reward,
            total_reward=self._total_reward,
            terminated=terminated,
            truncated=False,
            step_count=self._step_count,
        )

        # Scalar fields, coerced to int up front.
        map_id = int(snapshot.get("map_id", 0))
        player_x = int(snapshot.get("player_x", 0))
        player_y = int(snapshot.get("player_y", 0))
        badges = int(snapshot.get("badges", 0))
        money = int(snapshot.get("money", 0))

        # badge_count is the number of set bits in the badges value.
        badge_count = bin(badges).count("1")

        # Party: one PokemonData per entry, bad entries are skipped.
        party = []
        for entry in snapshot.get("party_pokemon", []):
            try:
                member = PokemonData(
                    species_id=int(entry.get("species_id", 0)),
                    level=int(entry.get("level", 1)),
                    hp_current=int(entry.get("hp_current", 1)),
                    hp_max=int(entry.get("hp_max", 1)),
                    xp=int(entry.get("xp", 0)),
                    hp_percentage=float(entry.get("hp_percentage", 100.0)),
                )
            except (TypeError, ValueError) as e:
                logging.warning(f"Error creating Pokemon data: {e}")
            else:
                party.append(member)

        # Inventory: one InventoryItem per entry, bad entries are skipped.
        inventory = []
        for entry in snapshot.get("inventory_items", []):
            try:
                stack = InventoryItem(
                    item_id=int(entry.get("item_id", 0)),
                    quantity=int(entry.get("quantity", 0)),
                )
            except (TypeError, ValueError) as e:
                logging.warning(f"Error creating inventory item: {e}")
            else:
                inventory.append(stack)

        # Assemble the public state from its three sub-structures.
        world = GameWorldState(map_id=map_id, player_x=player_x, player_y=player_y)
        progress = PlayerProgressState(
            badges=badges,
            badge_count=badge_count,
            money=money,
            step_count=self._step_count,
        )
        system = GameSystemState(
            in_battle=bool(snapshot.get("in_battle", False)),
            battle_outcome=int(snapshot.get("battle_outcome", 0)),
            menu_state=int(snapshot.get("menu_state", 0)),
            text_box_active=bool(snapshot.get("text_box_active", False)),
            warp_flag=int(snapshot.get("warp_flag", 0)),
        )
        public_state = PokemonRedPublicState(
            world=world,
            progress=progress,
            party=party,
            inventory=inventory,
            system=system,
        )

    except (TypeError, ValueError) as e:
        logging.error(f"Error creating states with data {snapshot}: {e}")
        # Fall back to a minimal pair that ends the episode and records why.
        private_state = PokemonRedPrivateState(
            reward_last_step=0.0,
            total_reward=0.0,
            terminated=True,
            truncated=False,
            step_count=self._step_count,
        )
        public_state = PokemonRedPublicState(
            world=GameWorldState(map_id=0, player_x=0, player_y=0),
            progress=PlayerProgressState(
                badges=0, badge_count=0, money=0, step_count=self._step_count
            ),
            party=[],
            inventory=[],
            system=GameSystemState(
                in_battle=False,
                battle_outcome=0,
                menu_state=0,
                text_box_active=False,
                warp_flag=0,
            ),
            error_info=f"State creation error: {e}",
        )

    return private_state, public_state
|
566
|
+
|
567
|
+
async def _reset_engine(
    self, *, seed: Optional[int] = None
) -> tuple[PokemonRedPrivateState, PokemonRedPublicState]:
    """Reset the Pokemon Red engine to its initial state.

    If the task instance carries an ``initial_engine_snapshot`` that is an
    existing ``Path`` and an emulator is attached, that save state is loaded
    first. Reward and step counters are then zeroed and the previous-state
    cache is refreshed from the emulator.

    Args:
        seed: Accepted for interface compatibility; not used by this method.

    Returns:
        The ``(private_state, public_state)`` pair from ``_create_states``
        with a reward of ``0.0``.
    """
    # Load initial save state if provided
    snapshot_path = getattr(self.task_instance, "initial_engine_snapshot", None)
    if (
        self.emulator is not None  # headless engines have nothing to load into
        and isinstance(snapshot_path, Path)
        and snapshot_path.exists()
    ):
        # Hand the emulator an open binary file rather than a path string,
        # matching how save_state/load_state are used elsewhere in this class.
        with snapshot_path.open("rb") as snapshot_file:
            self.emulator.load_state(snapshot_file)

    self._total_reward = 0.0
    self._step_count = 0
    self._previous_state = self._extract_current_state()

    return self._create_states(reward=0.0)
|
585
|
+
|
586
|
+
async def _step_engine(
    self, action: Dict[str, Any]
) -> tuple[PokemonRedPrivateState, PokemonRedPublicState]:
    """Execute one step in the Pokemon Red environment.

    Presses the requested button, extracts the new game state, asks the
    reward stack for the step reward and checks the termination condition.

    Args:
        action: Mapping with optional ``"button"`` (default ``"A"``) and
            ``"frames"`` (default ``1``) entries.

    Returns:
        The ``(private_state, public_state)`` pair from ``_create_states``.
        If anything in the step raises, the pair is created with a reward
        of ``-1.0`` and ``terminated=True``.
    """
    # Tracks whether this step has been counted yet, so the error path
    # below never counts the same step a second time.
    step_counted = False
    try:
        # Extract previous state for reward calculation
        prev_state = self._previous_state or self._extract_current_state()

        # Execute action (button press)
        button = action.get("button", "A")
        frames = action.get("frames", 1)

        self._press_button_with_retry(button, frames)

        self._step_count += 1
        step_counted = True

        # Extract new state
        current_state = self._extract_current_state()

        # Calculate reward using reward stack
        try:
            reward = await self.reward_stack.step_reward(
                state=current_state,
                action={
                    "prev_badges": int(prev_state.get("badges", 0)),
                    "prev_map_id": int(prev_state.get("map_id", 0)),
                    "prev_in_battle": bool(prev_state.get("in_battle", False)),
                    "prev_party_level": int(prev_state.get("party_level", 0)),
                    "prev_party_xp": int(prev_state.get("party_xp", 0)),
                },
            )
        except Exception as e:
            logging.error(f"Error calculating reward: {e}")
            reward = -0.01  # Small penalty for error

        self._total_reward += reward
        self._previous_state = current_state

        # Check termination condition (example: got Boulder Badge)
        try:
            badges = current_state.get("badges", 0)
            badges = int(badges) if badges is not None else 0
            # Terminate once the lowest bit of the badges value is set.
            terminated = (badges & 0x01) != 0
        except (TypeError, ValueError) as e:
            logging.error(
                f"Error checking termination condition with badges={current_state.get('badges')}: {e}"
            )
            terminated = False

        return self._create_states(reward=reward, terminated=terminated)

    except Exception as e:
        logging.error(f"Error in step engine: {e}")
        # Count the step on error too, but only if the failure happened
        # before the normal increment above.
        if not step_counted:
            self._step_count += 1
        # Return safe default states
        return self._create_states(reward=-1.0, terminated=True)
|
643
|
+
|
644
|
+
async def _serialize_engine(self) -> PokemonRedEngineSnapshot:
    """Serialize engine state for checkpointing.

    The emulator save state is captured into an in-memory buffer and stored
    under the ``"_save_state_bytes"`` key of the extracted game state, so no
    temporary file is created or left behind on disk. Without an emulator,
    a fixed placeholder byte string is stored instead.

    Returns:
        A ``PokemonRedEngineSnapshot`` holding the state data, the total
        reward and the step count.
    """
    import io

    if self.emulator is not None:
        # Write the save state straight into memory and take its bytes.
        buffer = io.BytesIO()
        self.emulator.save_state(buffer)
        state_bytes = buffer.getvalue()
    else:
        # For testing without emulator
        state_bytes = b"mock_state_data"

    current_state = self._extract_current_state()
    current_state["_save_state_bytes"] = state_bytes

    return PokemonRedEngineSnapshot(
        state_data=current_state,
        total_reward=self._total_reward,
        step_count=self._step_count,
    )
|
671
|
+
|
672
|
+
@classmethod
async def _deserialize_engine(
    cls, snapshot: PokemonRedEngineSnapshot, task_instance: TaskInstance
) -> "PokemonRedEngine":
    """Rebuild an engine from a checkpoint snapshot.

    A new engine is constructed for ``task_instance``. If the snapshot holds
    save-state bytes and the new engine has an emulator, those bytes are
    loaded into it. Reward total, step count and the previous-state cache
    (the snapshot data without the save-state bytes) are then restored.
    """
    restored = cls(task_instance)
    saved = snapshot.state_data
    blob_key = "_save_state_bytes"

    # Feed the stored bytes back to the emulator through an in-memory file.
    if blob_key in saved and restored.emulator is not None:
        import io

        restored.emulator.load_state(io.BytesIO(saved[blob_key]))

    restored._total_reward = snapshot.total_reward
    restored._step_count = snapshot.step_count
    # Everything except the raw save-state blob becomes the previous state.
    restored._previous_state = {
        key: saved[key] for key in saved if key != blob_key
    }

    return restored
|
@@ -0,0 +1 @@
|
|
1
|
+
# Engine helpers for Pokemon Red
|