synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.1.9.dist-info/METADATA +0 -37
- synth_ai-0.1.9.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,448 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
mcts_pokemon_red_env_example.py
|
4
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
5
|
+
Monte-Carlo-Tree-Search demo for Pokemon Red that:
|
6
|
+
• wraps Pokemon Red environment with real ROM
|
7
|
+
• stores every state in a FilesystemSnapshotStore
|
8
|
+
• expands / rolls-out with a TrajectoryTreeStore
|
9
|
+
• uses simple heuristics to guide exploration
|
10
|
+
• returns action sequence that makes progress
|
11
|
+
|
12
|
+
Run with pytest: pytest synth_ai/environments/examples/red/units/test_tree.py
|
13
|
+
"""
|
14
|
+
|
15
|
+
import asyncio
|
16
|
+
import gzip
|
17
|
+
import pickle
|
18
|
+
import random
|
19
|
+
import time
|
20
|
+
import logging
|
21
|
+
from pathlib import Path
|
22
|
+
|
23
|
+
import pytest
|
24
|
+
|
25
|
+
import sys
|
26
|
+
|
27
|
+
sys.path.append("/Users/joshuapurtell/Documents/GitHub/Environments/src")
|
28
|
+
|
29
|
+
from synth_ai.environments.reproducibility.tree import FilesystemSnapshotStore, TrajectoryTreeStore
|
30
|
+
from synth_ai.environments.examples.red.taskset import (
|
31
|
+
INSTANCE as DEFAULT_TASK,
|
32
|
+
)
|
33
|
+
from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
|
34
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
35
|
+
|
36
|
+
# Configure logging for DEBUG output with a message-only format, and expose
# the module's named logger.
logging.basicConfig(format="%(message)s", level=logging.DEBUG)
LOG = logging.getLogger("pokemon-mcts")

# Every button the planner may press, in the order the actions are evaluated.
POKEMON_ACTIONS = [
    "A",
    "B",
    "UP",
    "DOWN",
    "LEFT",
    "RIGHT",
    "START",
    "SELECT",
]
|
41
|
+
|
42
|
+
|
43
|
+
def heuristic_score(env: PokemonRedEnvironment) -> float:
    """
    Score the current Pokemon Red game state; higher scores are better.

    Starting from a base of 10.0 the score adds:
      * 100.0 per set bit in ``pub.badges``
      * 5.0 per ``pub.party_level``
      * 0.001 per ``pub.party_xp``
      * 10.0 when ``pub.map_id > 0``
      * 5.0 when the player is not at (0, 0)
      * ``2.0 * current_hp / max_hp`` when max HP is positive, else a flat 1.0
    and subtracts 0.001 per ``pub.step_count``.

    Args:
        env: Environment whose ``engine._create_states(reward=0.0)`` returns a
            pair; the second element is the state read here.

    Returns:
        The score, never lower than 0.1. If reading the state raises, the
        error is logged at DEBUG level and 0.1 is returned.
    """
    try:
        # Only the second element of the state pair is used for scoring.
        _, pub = env.engine._create_states(reward=0.0)

        score = 10.0  # Base score to avoid all zeros

        # Badge progress (most important): badges is a bitfield, count set bits.
        badge_count = bin(pub.badges).count("1")
        score += badge_count * 100.0  # 100 points per badge

        # Level progress
        score += pub.party_level * 5.0  # 5 points per level

        # XP progress (smaller contribution)
        score += pub.party_xp * 0.001  # Very small XP bonus

        # Exploration bonus - being in different maps
        if pub.map_id > 0:
            score += 10.0  # Bonus for being in actual game world

        # Position exploration bonus - reward movement from (0,0)
        if pub.player_x != 0 or pub.player_y != 0:
            score += 5.0

        # HP bonus - encourage keeping Pokemon healthy (only if we have a Pokemon)
        if pub.party_hp_max > 0:
            hp_ratio = pub.party_hp_current / pub.party_hp_max
            score += hp_ratio * 2.0
        else:
            # No penalty for not having a Pokemon initially
            score += 1.0

        # Step efficiency penalty (very small)
        score -= pub.step_count * 0.001

        return max(score, 0.1)  # Ensure minimum positive score

    except Exception as e:
        # Best-effort: a broken state still yields the floor score.
        LOG.debug("Heuristic evaluation error: %s", e)
        return 0.1
|
88
|
+
|
89
|
+
|
90
|
+
def is_terminal_state(env: PokemonRedEnvironment) -> bool:
    """
    Check if we've reached a terminal state (won or lost).

    Args:
        env: Environment whose ``engine._create_states(reward=0.0)`` returns a
            pair; the second element is the state read here.

    Returns:
        True when bit 0 of ``pub.badges`` is set, when current HP is 0 while
        max HP is positive, or when reading the state raises (the error is
        logged at DEBUG level). False otherwise.
    """
    try:
        # Only the second element of the state pair is inspected.
        _, pub = env.engine._create_states(reward=0.0)

        # Terminal if we got the Boulder Badge (task completion)
        if pub.badges & 0x01:
            return True

        # Only consider terminal if HP is 0 AND max HP > 0 (meaning we actually
        # have a Pokemon). Initial state might have 0/0 HP which isn't a loss.
        if pub.party_hp_current == 0 and pub.party_hp_max > 0:
            return True

        return False

    except Exception as e:
        # Consider error states as terminal, but leave a trace like the other
        # helpers in this file do instead of failing silently.
        LOG.debug("Terminal state check error: %s", e)
        return True
|
108
|
+
|
109
|
+
|
110
|
+
async def simple_rollout(env: PokemonRedEnvironment, max_steps: int = 20) -> float:
    """
    Perform a simple random rollout from the current state.

    The engine is snapshotted, up to ``max_steps`` random buttons are pressed
    (stopping early on a terminal state or a failed step), the resulting state
    is scored with ``heuristic_score``, and the snapshot is restored.

    Args:
        env: Environment to roll out from; its engine is restored afterwards.
        max_steps: Maximum number of random button presses.

    Returns:
        The heuristic score after the rollout, or 0.0 if snapshotting or
        restoring raised (the error is logged at DEBUG level).
    """
    try:
        # Save current state
        snapshot = await env._serialize_engine()

        # Random walk
        for _ in range(max_steps):
            if is_terminal_state(env):
                break

            # Choose random action
            action = random.choice(POKEMON_ACTIONS)
            call = EnvToolCall(tool="press_button", args={"button": action, "frames": 1})

            try:
                await env.step(call)
            except Exception:
                break  # Stop on error

        # Evaluate final state
        final_score = heuristic_score(env)

        # Restore original state. Elsewhere in this file the object returned by
        # PokemonRedEnvironment._deserialize_engine is used as an environment
        # (its ``.engine`` is read), so take that engine rather than nesting
        # the whole object inside ``env.engine``. Fall back to the object
        # itself in case it has no ``.engine`` attribute.
        restored = await PokemonRedEnvironment._deserialize_engine(snapshot, env.task_instance)
        env.engine = getattr(restored, "engine", restored)

        return final_score

    except Exception as e:
        LOG.debug("Rollout error: %s", e)
        return 0.0
|
144
|
+
|
145
|
+
|
146
|
+
def _child_for_action(tree: TrajectoryTreeStore, node_id: str, action: str) -> "str | None":
    """Return the id of the first child of ``node_id`` whose edge action is ``action``, else None."""
    return next(
        (
            cid
            for cid in tree.get_children(node_id)
            if tree.graph[node_id][cid]["action"] == action
        ),
        None,
    )


async def _env_from_blob(blob):
    """Rebuild an environment from a gzip-compressed, pickled engine snapshot blob."""
    # NOTE(security): pickle.loads executes arbitrary code from its input. The
    # blobs decoded here are the ones this module itself wrote into the tree's
    # snapshot store; never point this at a store holding untrusted data.
    snapshot = pickle.loads(gzip.decompress(blob))
    return await PokemonRedEnvironment._deserialize_engine(snapshot, DEFAULT_TASK)


async def pokemon_red_mcts_plan(
    tree: TrajectoryTreeStore,
    root_id: str,
    *,
    rollouts_per_action: int = 10,
    max_depth: int = 20,
    timeout_s: float = 30.0,
) -> tuple[list[str], list[dict[str, float]]]:
    """
    MCTS planning for Pokemon Red.

    Starting at ``root_id``, each depth level evaluates every action in
    ``POKEMON_ACTIONS``: the child node for the action is reused if the tree
    already has one, otherwise it is created by stepping a fresh copy of the
    node's environment. The action's Q-value is the mean of
    ``rollouts_per_action`` random rollouts from the child, and the planner
    then descends into the child of the highest-valued action.

    Args:
        tree: Trajectory tree holding the snapshot blobs and action edges.
        root_id: Id of the node to start planning from.
        rollouts_per_action: Rollouts averaged per action at each depth.
        max_depth: Maximum number of actions in the returned plan.
        timeout_s: Wall-clock budget in seconds; ``None`` disables the check.

    Returns:
        ``(action_plan, q_value_history)``. Both lists grow in lockstep: entry
        ``i`` of the history is the action -> Q-value table from which
        ``action_plan[i]`` was chosen.
    """
    start = time.monotonic()

    def timed_out() -> bool:
        # Same check at every level of the search; None means "no deadline".
        return timeout_s is not None and time.monotonic() - start >= timeout_s

    plan: list[str] = []
    q_hist: list[dict[str, float]] = []
    node_id = root_id

    for depth in range(max_depth):
        LOG.debug(f"\n--- MCTS depth {depth} --- node={node_id[:8]}")

        if timed_out():
            LOG.debug("MCTS timeout reached")
            break

        # Load environment from snapshot
        env_blob = tree.load_snapshot_blob(node_id)
        env = await _env_from_blob(env_blob)

        # Check if terminal
        if is_terminal_state(env):
            LOG.debug("Terminal state reached")
            break

        # Log current state (only the second element of the pair is needed)
        _, pub = env.engine._create_states(reward=0.0)
        LOG.debug(
            f"State: Map{pub.map_id:02X}:({pub.player_x},{pub.player_y}) "
            f"Badges:{bin(pub.badges).count('1')} Level:{pub.party_level} "
            f"HP:{pub.party_hp_current}/{pub.party_hp_max}"
        )

        q_vals: dict[str, float] = {}

        # Evaluate each possible action
        for action in POKEMON_ACTIONS:
            if timed_out():
                break

            # Check if we already have a child for this action
            child_id = _child_for_action(tree, node_id, action)

            if child_id is None:  # Need to expand
                LOG.debug(f"Expanding action: {action}")

                # Create new environment and take action. The blob is decoded
                # afresh so the step cannot disturb the parent's state.
                try:
                    tmp_env = await _env_from_blob(env_blob)

                    call = EnvToolCall(tool="press_button", args={"button": action, "frames": 1})
                    await tmp_env.step(call)

                    # Create child node
                    child_blob = gzip.compress(pickle.dumps(await tmp_env._serialize_engine()))
                    child_id = tree.add_child(
                        node_id,
                        child_blob,
                        action=action,
                        reward=heuristic_score(tmp_env),
                        terminated=is_terminal_state(tmp_env),
                        info={},
                    )

                except Exception as e:
                    # An action that cannot be expanded is simply left out of
                    # this level's Q-table.
                    LOG.debug(f"Failed to expand action {action}: {e}")
                    continue
            else:
                LOG.debug(f"Reusing existing child for action: {action}")

            if child_id is None:
                continue

            # Perform rollouts from child state
            child_env = await _env_from_blob(tree.load_snapshot_blob(child_id))

            rollout_scores = []
            for _ in range(rollouts_per_action):
                if timed_out():
                    break
                score = await simple_rollout(child_env, max_steps=10)
                rollout_scores.append(score)

            if rollout_scores:
                # Average rollout score as Q-value
                q_vals[action] = sum(rollout_scores) / len(rollout_scores)
                LOG.debug(
                    f"Action {action}: Q={q_vals[action]:.3f} "
                    f"(avg of {len(rollout_scores)} rollouts)"
                )
            else:
                # The deadline hit before any rollout ran for this action.
                q_vals[action] = 0.0

        if not q_vals:
            LOG.debug("No valid actions found")
            break

        LOG.debug(f"Q-values: {q_vals}")
        q_hist.append(q_vals)

        # Select best action
        best_action = max(q_vals, key=q_vals.get)
        plan.append(best_action)

        # Move to child node
        next_node = _child_for_action(tree, node_id, best_action)

        if next_node is None:
            LOG.debug(f"No child node found for action {best_action}")
            break

        node_id = next_node

        LOG.debug(f"Selected action: {best_action} → node={node_id[:8]}")

    return plan, q_hist
|
287
|
+
|
288
|
+
|
289
|
+
@pytest.mark.asyncio
async def test_mcts_pokemon_red_basic(tmp_path: Path) -> None:
    """
    Test basic MCTS functionality with Pokemon Red.

    Runs a small planning session from a freshly initialized environment and
    checks the structural invariants of the result: the plan is bounded by
    ``max_depth``, holds only known actions, and every planned action comes
    with the Q-value table it was chosen from.
    """
    max_depth = 5  # Shallow depth for testing

    # Create environment
    env = PokemonRedEnvironment(DEFAULT_TASK)
    await env.initialize()

    # Set up tree storage
    snap_store_path = tmp_path / "pokemon_mcts_snaps"
    tree = TrajectoryTreeStore(FilesystemSnapshotStore(snap_store_path))

    # Add root snapshot
    root_blob = gzip.compress(pickle.dumps(await env._serialize_engine()))
    root_id = tree.add_root(root_blob)

    LOG.debug("Starting Pokemon Red MCTS planning...")

    # Run MCTS with short timeout for testing
    plan, q_hist = await pokemon_red_mcts_plan(
        tree,
        root_id,
        rollouts_per_action=3,  # Reduced for faster testing
        max_depth=max_depth,
        timeout_s=10.0,  # Short timeout
    )

    print(f"MCTS Plan: {plan}")
    print(f"Q-value history: {q_hist}")

    # Verify we got a plan no longer than the depth limit. (The plan may be
    # empty, e.g. when the root is terminal, so no lower bound is asserted.)
    assert isinstance(plan, list), "Plan should be a list"
    assert len(plan) <= max_depth, "Plan should not exceed max_depth"

    # Verify all actions in plan are valid
    for action in plan:
        assert action in POKEMON_ACTIONS, f"Invalid action in plan: {action}"

    # Verify Q-values were computed: the planner records one Q-table per
    # planned action, and the chosen action must appear in its own table.
    assert isinstance(q_hist, list), "Q-history should be a list"
    assert len(q_hist) == len(plan), "Each planned action should have a Q-value entry"
    for chosen, q_dict in zip(plan, q_hist):
        assert isinstance(q_dict, dict), "Each Q-value entry should be a dict"
        assert chosen in q_dict, f"Chosen action missing from its Q-values: {chosen}"
        for action, q_val in q_dict.items():
            assert action in POKEMON_ACTIONS, f"Invalid action in Q-values: {action}"
            assert isinstance(q_val, (int, float)), f"Q-value should be numeric: {q_val}"
|
334
|
+
|
335
|
+
|
336
|
+
@pytest.mark.asyncio
async def test_mcts_pokemon_red_execution(tmp_path: Path) -> None:
    """
    Test that MCTS plan can be executed in Pokemon Red.

    Plans from a fresh environment, replays the planned buttons on that same
    environment, and checks that at least one step was counted per action.
    """
    # Fresh environment; remember where it started for the comparison below.
    env = PokemonRedEnvironment(DEFAULT_TASK)
    await env.initialize()

    _, initial_pub = env.engine._create_states(reward=0.0)
    initial_score = heuristic_score(env)

    LOG.debug(
        f"Initial state - Score: {initial_score:.3f}, "
        f"Map: {initial_pub.map_id}, Pos: ({initial_pub.player_x},{initial_pub.player_y}), "
        f"Level: {initial_pub.party_level}, Badges: {bin(initial_pub.badges).count('1')}"
    )

    # Seed a trajectory tree with the starting snapshot.
    store = FilesystemSnapshotStore(tmp_path / "pokemon_execution_test")
    tree = TrajectoryTreeStore(store)
    serialized_engine = await env._serialize_engine()
    root_id = tree.add_root(gzip.compress(pickle.dumps(serialized_engine)))

    # Plan; only the action list is needed here.
    plan, _ = await pokemon_red_mcts_plan(
        tree,
        root_id,
        rollouts_per_action=2,
        max_depth=3,
        timeout_s=8.0,
    )

    # Replay the plan on the live environment.
    for step_no, action in enumerate(plan, start=1):
        LOG.debug(f"Executing step {step_no}: {action}")
        button_press = EnvToolCall(tool="press_button", args={"button": action, "frames": 1})
        obs = await env.step(button_press)

        # Log progress
        LOG.debug(f" → Step {obs['step_count']}, Reward: {obs['total_reward']:.3f}")

    # Where did we end up?
    _, final_pub = env.engine._create_states(reward=0.0)
    final_score = heuristic_score(env)

    LOG.debug(
        f"Final state - Score: {final_score:.3f}, "
        f"Map: {final_pub.map_id}, Pos: ({final_pub.player_x},{final_pub.player_y}), "
        f"Level: {final_pub.party_level}, Badges: {bin(final_pub.badges).count('1')}"
    )

    # Verify execution worked
    assert final_pub.step_count >= len(plan), "Steps should have been executed"

    # Any change in location, level, badges or score counts as progress.
    end_location = (final_pub.map_id, final_pub.player_x, final_pub.player_y)
    start_location = (initial_pub.map_id, initial_pub.player_x, initial_pub.player_y)
    progress_made = (
        end_location != start_location
        or final_pub.party_level > initial_pub.party_level
        or final_pub.badges != initial_pub.badges
        or abs(final_score - initial_score) > 0.01
    )

    LOG.debug(f"Progress made: {progress_made}")
    # Note: Progress isn't guaranteed in a short test, so we just verify execution worked
|
399
|
+
|
400
|
+
|
401
|
+
@pytest.mark.asyncio
async def test_heuristic_functions() -> None:
    """
    Test the heuristic and utility functions.

    Checks the return types of ``heuristic_score``, ``is_terminal_state`` and
    ``simple_rollout`` on a freshly initialized environment, plus the lower
    bounds those functions guarantee.
    """
    # Create test environment
    env = PokemonRedEnvironment(DEFAULT_TASK)
    await env.initialize()

    # Test heuristic scoring: the function clamps its result to at least 0.1,
    # including on its error path, so a smaller value means the clamp broke.
    initial_score = heuristic_score(env)
    assert isinstance(initial_score, (int, float)), "Heuristic should return numeric score"
    assert initial_score >= 0.1, "Heuristic score should never drop below its 0.1 floor"

    # Test terminal state detection
    is_terminal = is_terminal_state(env)
    assert isinstance(is_terminal, bool), "Terminal check should return boolean"

    # Test rollout (with very short length): it returns either a heuristic
    # score or 0.0 on error, never a negative value.
    rollout_score = await simple_rollout(env, max_steps=3)
    assert isinstance(rollout_score, (int, float)), "Rollout should return numeric score"
    assert rollout_score >= 0.0, "Rollout score should be non-negative"

    LOG.debug(
        f"Heuristic tests - Initial: {initial_score:.3f}, "
        f"Terminal: {is_terminal}, Rollout: {rollout_score:.3f}"
    )
|
426
|
+
|
427
|
+
|
428
|
+
if __name__ == "__main__":
    import tempfile

    async def main():
        """Run the three tests in order without pytest, sharing one temporary directory."""
        with tempfile.TemporaryDirectory() as scratch:
            scratch_dir = Path(scratch)

            print("Running Pokemon Red MCTS tests...")

            # The heuristic test needs no scratch space.
            await test_heuristic_functions()
            print("✓ Heuristic functions test passed")

            # The MCTS tests each create their own sub-directory in scratch_dir.
            tree_tests = (
                (test_mcts_pokemon_red_basic, "✓ Basic MCTS test passed"),
                (test_mcts_pokemon_red_execution, "✓ MCTS execution test passed"),
            )
            for run_test, banner in tree_tests:
                await run_test(scratch_dir)
                print(banner)

            print("🎉 All Pokemon Red MCTS tests passed!")

    asyncio.run(main())
|
@@ -0,0 +1 @@
|
|
1
|
+
"""Sokoban environment example."""
|