synth-ai 0.2.0__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +35 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +1 -1
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.2.0.dist-info/METADATA +0 -36
- synth_ai-0.2.0.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info/licenses}/LICENSE +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1471 @@
|
|
1
|
+
import asyncio
|
2
|
+
import uuid
|
3
|
+
import pytest
|
4
|
+
import json
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Dict, Any, List, Optional, Deque, Literal
|
7
|
+
from pydantic import BaseModel, Field, validator
|
8
|
+
from collections import deque
|
9
|
+
from synth_ai.zyk import LM
|
10
|
+
from synth_ai.zyk.lms.tools.base import BaseTool
|
11
|
+
from synth_sdk.tracing.decorators import trace_event_async
|
12
|
+
from synth_sdk.tracing.abstractions import RewardSignal, Dataset, TrainingQuestion
|
13
|
+
from synth_sdk.tracing.utils import get_system_id
|
14
|
+
|
15
|
+
# Monkey patch the LM cache handler so that messages with mixed content types
# (e.g. text parts plus image parts) pass cache validation.
# The caching package is shipped at ``synth_ai.lm.caching`` in this package
# version (previously ``synth_ai.zyk.lms.caching``), so try the new location
# first and fall back to the legacy one for older installs.
try:
    try:
        from synth_ai.lm.caching.handler import CacheHandler
    except ImportError:
        from synth_ai.zyk.lms.caching.handler import CacheHandler

    original_validate_messages = CacheHandler._validate_messages

    def patched_validate_messages(self, messages: List[Dict[str, Any]]) -> None:
        """Validate messages loosely - PATCHED to allow mixed content for images.

        Only checks that every message is a dict carrying a ``content`` key;
        the type of ``content`` (string or list of parts) is not constrained.

        Raises:
            AssertionError: if any message is not a dict or lacks ``content``.
        """
        # Raise explicitly instead of using ``assert`` so the check still runs
        # under ``python -O`` while keeping the original exception type.
        if not all(isinstance(msg, dict) and "content" in msg for msg in messages):
            raise AssertionError("All messages must be dicts with content")

    CacheHandler._validate_messages = patched_validate_messages
    print("[DEBUG] Successfully monkey patched zyk cache validation to support images")
except Exception as e:
    # Best-effort: the patch is optional, so report the failure and keep going.
    print(f"[DEBUG] Failed to monkey patch zyk cache validation: {e}")
    # Continue anyway - the assertion might not be hit in all cases
|
33
|
+
|
34
|
+
# Pokemon Red specific imports
|
35
|
+
from synth_ai.environments.examples.red.environment import (
|
36
|
+
PokemonRedEnvironment,
|
37
|
+
PokemonRedPublicState,
|
38
|
+
PokemonRedPrivateState,
|
39
|
+
)
|
40
|
+
|
41
|
+
# Import early game reward components
|
42
|
+
from synth_ai.environments.examples.red.engine_helpers.reward_library.pallet_town_rewards import (
|
43
|
+
LeaveStartingRoomReward,
|
44
|
+
TalkToMomReward,
|
45
|
+
InteractWithTVReward,
|
46
|
+
CheckComputerReward,
|
47
|
+
ExitHouseReward,
|
48
|
+
ExploreTownReward,
|
49
|
+
TalkToNPCsReward,
|
50
|
+
OakLabDiscoveryReward,
|
51
|
+
AttemptRoute1Reward,
|
52
|
+
ChooseStarterPokemonReward,
|
53
|
+
DoorInteractionReward,
|
54
|
+
ObjectInteractionReward,
|
55
|
+
TryAllDirectionsReward,
|
56
|
+
)
|
57
|
+
from synth_ai.environments.examples.red.engine_helpers.reward_library.exploration_rewards import (
|
58
|
+
NewAreaDiscoveryReward,
|
59
|
+
BuildingEntryReward,
|
60
|
+
)
|
61
|
+
from synth_ai.environments.examples.red.engine_helpers.reward_library.novelty_rewards import (
|
62
|
+
FirstBattleReward,
|
63
|
+
FirstPokemonCenterVisitReward,
|
64
|
+
)
|
65
|
+
|
66
|
+
from synth_ai.environments.environment.shared_engine import (
|
67
|
+
GetObservationCallable,
|
68
|
+
InternalObservation,
|
69
|
+
)
|
70
|
+
from synth_ai.environments.examples.red.taskset import PokemonRedTaskInstance
|
71
|
+
from synth_ai.environments.tasks.core import Impetus, Intent, TaskInstanceMetadata
|
72
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
73
|
+
|
74
|
+
# Import screen analysis functions
|
75
|
+
from synth_ai.environments.examples.red.engine_helpers.screen_analysis import (
|
76
|
+
analyze_screen_buffer,
|
77
|
+
create_detailed_screen_description,
|
78
|
+
)
|
79
|
+
|
80
|
+
import logging
|
81
|
+
|
82
|
+
# Disable every standard-library logging call at level CRITICAL and below,
# i.e. silence all logging in this process, as an import-time side effect.
# NOTE(review): this also mutes logging from other modules and test runners
# sharing the process; consider scoping it to a main()/fixture instead.
logging.disable(logging.CRITICAL)
|
83
|
+
|
84
|
+
|
85
|
+
# --- Early Game Reward Manager ---
|
86
|
+
class EarlyGameRewardManager:
    """Aggregate early-game reward components for Pokemon Red.

    Every component is scored on each state transition. Only strictly positive
    scores count toward the step total. Running totals and a per-step history
    of rewarded transitions are kept on the instance.
    """

    def __init__(self):
        # Pallet Town house exploration
        house = [
            LeaveStartingRoomReward(),
            TalkToMomReward(),
            InteractWithTVReward(),
            CheckComputerReward(),
            ExitHouseReward(),
        ]
        # Town and building exploration
        town = [
            ExploreTownReward(),
            TalkToNPCsReward(),
            NewAreaDiscoveryReward(),
            BuildingEntryReward(),
        ]
        # Story progression
        story = [
            OakLabDiscoveryReward(),
            AttemptRoute1Reward(),
            ChooseStarterPokemonReward(),
        ]
        # Basic interactions
        interactions = [
            DoorInteractionReward(),
            ObjectInteractionReward(),
            TryAllDirectionsReward(),
        ]
        # First time experiences
        firsts = [
            FirstBattleReward(),
            FirstPokemonCenterVisitReward(),
        ]
        self.rewards = house + town + story + interactions + firsts

        self.total_reward_earned = 0.0
        self.reward_history = []

    async def calculate_rewards(
        self,
        current_state: Dict[str, Any],
        prev_state: Dict[str, Any],
        action_info: Dict[str, Any],
    ) -> float:
        """Score one state transition with every reward component.

        Args:
            current_state: State after the transition. ``step_count`` is read
                for the history entry and the log line.
            prev_state: State before the transition. Selected keys are exposed
                to components under a ``prev_`` prefix, with fallbacks when absent.
            action_info: Extra context merged in last, so its keys override any
                ``prev_*`` entries with the same name.

        Returns:
            The sum of this step's strictly positive component scores.
            A component that raises is reported and skipped.
        """
        # (key in prev_state, fallback) pairs. The tuple is rebuilt per call,
        # so each call gets fresh list fallbacks.
        previous_fields = (
            ("map_id", -1),
            ("player_x", -1),
            ("player_y", -1),
            ("text_box_active", False),
            ("in_battle", False),
            ("party", []),
            ("inventory", []),
            ("money", 0),
        )
        action_context = {
            "prev_" + field: prev_state.get(field, fallback)
            for field, fallback in previous_fields
        }
        action_context.update(action_info)

        step_total = 0.0
        earned = []
        for component in self.rewards:
            component_name = component.__class__.__name__
            try:
                score = await component.score(current_state, action_context)
                if not score > 0:
                    continue
                step_total += score
                earned.append({"component": component_name, "reward": score})
                print(f"[REWARD] {component_name}: +{score:.1f}")
            except Exception as err:
                print(f"[REWARD_ERROR] {component_name}: {err}")

        if step_total > 0:
            self.total_reward_earned += step_total
            step_number = current_state.get("step_count", 0)
            self.reward_history.append(
                {
                    "step": step_number,
                    "total_reward": step_total,
                    "components": earned,
                }
            )
            print(
                f"[REWARD_TOTAL] Step {step_number}: +{step_total:.1f} "
                f"(Total: {self.total_reward_earned:.1f})"
            )

        return step_total
|
173
|
+
|
174
|
+
|
175
|
+
# --- Helper function to format observation for LLM ---
|
176
|
+
def _summarize_screen_analysis(screen_analysis: dict) -> str:
    """Return the compact (non-ASCII) screen summary used outside ``state_and_ascii`` mode.

    Reads the optional keys ``screen_type``, ``colors`` (a mapping, presumably
    color -> percentage), ``entities`` (only its length is used) and
    ``ui_elements`` (joined with ", "). Absent keys are skipped.
    """
    summary = f"SCREEN TYPE: {screen_analysis.get('screen_type', 'UNKNOWN')}\n"

    # Add color analysis
    if "colors" in screen_analysis:
        colors_text = "DOMINANT COLORS: " + ", ".join(
            f"{color}({pct}%)" for color, pct in screen_analysis["colors"].items()
        )
        summary += colors_text + "\n"

    # Add entity detection summary
    if "entities" in screen_analysis:
        summary += (
            f"DETECTED ENTITIES: {len(screen_analysis['entities'])} sprite-like objects\n"
        )

    # Add UI elements; an empty/falsy value omits the line entirely.
    if "ui_elements" in screen_analysis:
        ui_elements = screen_analysis["ui_elements"]
        if ui_elements:
            summary += f"UI: {', '.join(ui_elements)} detected\n"

    return summary


def format_obs_for_llm_from_states(
    pub: PokemonRedPublicState,
    priv: PokemonRedPrivateState,
    screen_analysis: Optional[dict] = None,
    mode: str = "state_and_screen",
) -> str:
    """Render the Pokemon Red public/private state as a text report for an LLM.

    Sections, in order:
    - step counter
    - optional screen analysis
    - world location
    - player progress
    - party
    - inventory (first 8 items listed)
    - game system flags
    - reward/termination info
    - optional error
    - end marker

    Args:
        pub: Public state. Reads ``progress``, ``world``, ``party``,
            ``inventory``, ``system`` and ``error_info``.
        priv: Private state. Reads ``reward_last_step``, ``total_reward``,
            ``terminated`` and ``truncated``.
        screen_analysis: Optional screen-analysis dict. When falsy, the visual
            section is omitted entirely.
        mode: ``"state_and_ascii"`` embeds the full description from
            ``create_detailed_screen_description``. Any other value emits only
            a compact summary.

    Returns:
        The report lines joined with newlines.
    """
    obs_lines = [
        "=== POKEMON RED GAME STATE ===",
        f"Step: {pub.progress.step_count}",
    ]

    # === VISUAL SCREEN INFORMATION ===
    if screen_analysis:
        obs_lines.extend(["", "=== VISUAL SCREEN ANALYSIS ==="])
        # Only include the ASCII rendering in state_and_ascii mode.
        if mode == "state_and_ascii":
            screen_description = create_detailed_screen_description(screen_analysis)
        else:
            screen_description = _summarize_screen_analysis(screen_analysis)
        obs_lines.append(screen_description)

    # === WORLD INFORMATION ===
    obs_lines.extend(
        [
            "",
            "=== WORLD LOCATION ===",
            f"Map ID: {pub.world.map_id} | Position: ({pub.world.player_x}, {pub.world.player_y})",
        ]
    )

    # === PLAYER PROGRESS ===
    obs_lines.extend(
        [
            "",
            "=== PLAYER PROGRESS ===",
            f"Badges: {pub.progress.badge_count}/8 (0x{pub.progress.badges:02X})",
            f"Money: ${pub.progress.money:,}",
        ]
    )

    # === POKEMON PARTY ===
    obs_lines.extend(["", "=== POKEMON PARTY ==="])
    if pub.party:
        for i, pokemon in enumerate(pub.party, 1):
            status_icon = "●" if pokemon.hp_current > 0 else "✗"
            obs_lines.append(
                f"{i}. Species#{pokemon.species_id:03d} L{pokemon.level} | "
                f"HP:{pokemon.hp_current}/{pokemon.hp_max} ({pokemon.hp_percentage:.1f}%) {status_icon} | "
                f"XP:{pokemon.xp:,}"
            )
    else:
        obs_lines.append("No Pokemon in party")

    # === INVENTORY ===
    obs_lines.extend(["", "=== INVENTORY ==="])
    if pub.inventory:
        # Show first 8 items with quantities
        for item in pub.inventory[:8]:
            obs_lines.append(f"Item#{item.item_id:03d} x{item.quantity}")
        if len(pub.inventory) > 8:
            obs_lines.append(f"... and {len(pub.inventory) - 8} more items")
        obs_lines.append(f"Total Items: {len(pub.inventory)}")
    else:
        obs_lines.append("No items in inventory")

    # === GAME SYSTEM STATE ===
    # Raw flags, shown without interpretation.
    obs_lines.extend(["", "=== GAME SYSTEM STATE ==="])
    if pub.system.in_battle:
        obs_lines.append("In Battle: True")
        obs_lines.append(f"Battle Outcome: {pub.system.battle_outcome}")
    else:
        obs_lines.append("In Battle: False")
    obs_lines.append(f"Text Box Active: {bool(pub.system.text_box_active)}")
    obs_lines.append(f"Warp Flag: {pub.system.warp_flag}")

    # === TECHNICAL INFO ===
    obs_lines.extend(
        [
            "",
            "=== TECHNICAL INFO ===",
            f"Last Reward: {priv.reward_last_step:.3f}",
            f"Total Reward: {priv.total_reward:.3f}",
            f"Terminated: {priv.terminated} | Truncated: {priv.truncated}",
        ]
    )

    if pub.error_info:
        obs_lines.append(f"Error: {pub.error_info}")

    obs_lines.append("=== END GAME STATE ===")

    return "\n".join(obs_lines)
|
307
|
+
|
308
|
+
|
309
|
+
# --- Custom observation callable for Pokemon Red ---
|
310
|
+
class PokemonRedHistoryObservationCallable(GetObservationCallable):
|
311
|
+
def __init__(
    self,
    max_history: int = 1,
    mode: Literal["state_and_ascii", "state_and_screen"] = "state_and_screen",
):
    """Set up bounded histories and per-episode bookkeeping.

    Args:
        max_history: Maximum length of each history deque. Once full, the
            oldest entry is dropped.
        mode: Observation formatting mode, stored as ``self.mode``.
    """
    # Histories of formatted observations (str), public states and private
    # states, each bounded to ``max_history`` entries.
    self._hist_obs, self._hist_pub_state, self._hist_priv_state = (
        deque(maxlen=max_history) for _ in range(3)
    )

    # Stuck-detection bookkeeping.
    self._stuck_count = 0
    self._last_state_hash = None

    self.screen_buffer = None  # Store screen buffer for agent access
    self.mode = mode  # Store mode for observation formatting

    # Early-game shaping rewards, plus the previous state dict they are
    # computed against (None until the first observation).
    self.reward_manager = EarlyGameRewardManager()
    self._last_state_dict = None
|
327
|
+
|
328
|
+
async def get_observation(
|
329
|
+
self, pub: PokemonRedPublicState, priv: PokemonRedPrivateState
|
330
|
+
) -> InternalObservation:
|
331
|
+
if pub is None or priv is None:
|
332
|
+
raise RuntimeError("Missing public or private state in get_observation - HARD FAIL")
|
333
|
+
|
334
|
+
# Create current state dict for reward calculation
|
335
|
+
current_state_dict = {
|
336
|
+
"map_id": pub.map_id,
|
337
|
+
"player_x": pub.player_x,
|
338
|
+
"player_y": pub.player_y,
|
339
|
+
"step_count": pub.step_count,
|
340
|
+
"text_box_active": pub.system.text_box_active,
|
341
|
+
"in_battle": pub.system.in_battle,
|
342
|
+
"party": [
|
343
|
+
{
|
344
|
+
"species_id": p.species_id,
|
345
|
+
"level": p.level,
|
346
|
+
"hp_current": p.hp_current,
|
347
|
+
"hp_max": p.hp_max,
|
348
|
+
}
|
349
|
+
for p in pub.party
|
350
|
+
],
|
351
|
+
"inventory": [
|
352
|
+
{"item_id": item.item_id, "quantity": item.quantity} for item in pub.inventory
|
353
|
+
],
|
354
|
+
"money": pub.progress.money,
|
355
|
+
"badges": pub.progress.badges,
|
356
|
+
}
|
357
|
+
|
358
|
+
# Calculate rewards if we have a previous state
|
359
|
+
additional_reward = 0.0
|
360
|
+
if self._last_state_dict is not None:
|
361
|
+
try:
|
362
|
+
additional_reward = await self.reward_manager.calculate_rewards(
|
363
|
+
current_state_dict,
|
364
|
+
self._last_state_dict,
|
365
|
+
{"buttons_pressed": []}, # Could track actual buttons if needed
|
366
|
+
)
|
367
|
+
except Exception as e:
|
368
|
+
print(f"[REWARD_ERROR] Failed to calculate rewards: {e}")
|
369
|
+
|
370
|
+
# Store current state for next iteration
|
371
|
+
self._last_state_dict = current_state_dict.copy()
|
372
|
+
|
373
|
+
# Check if we're stuck (same position and menu state for multiple steps)
|
374
|
+
# Use property accessors that handle the new state structure
|
375
|
+
current_state_hash = hash((pub.player_x, pub.player_y, pub.map_id, pub.step_count))
|
376
|
+
if self._last_state_hash == current_state_hash and pub.step_count > 1:
|
377
|
+
self._stuck_count += 1
|
378
|
+
if self._stuck_count >= 3:
|
379
|
+
raise RuntimeError(
|
380
|
+
f"Agent stuck in same state for {self._stuck_count} steps - HARD FAIL. Position: ({pub.player_x}, {pub.player_y}), Map: {pub.map_id}"
|
381
|
+
)
|
382
|
+
else:
|
383
|
+
self._stuck_count = 0
|
384
|
+
self._last_state_hash = current_state_hash
|
385
|
+
|
386
|
+
# Extract screen buffer for agent vision - FAIL HARD if screen access doesn't work
|
387
|
+
additional_context = ""
|
388
|
+
screen_analysis = None
|
389
|
+
|
390
|
+
try:
|
391
|
+
# Look for environment in call stack to access engine/emulator
|
392
|
+
import inspect
|
393
|
+
|
394
|
+
frame = inspect.currentframe()
|
395
|
+
env = None
|
396
|
+
|
397
|
+
# Walk up the call stack to find the environment
|
398
|
+
while frame:
|
399
|
+
if "self" in frame.f_locals and hasattr(frame.f_locals["self"], "engine"):
|
400
|
+
env = frame.f_locals["self"]
|
401
|
+
break
|
402
|
+
frame = frame.f_back
|
403
|
+
|
404
|
+
if not env or not hasattr(env, "engine") or not env.engine:
|
405
|
+
raise RuntimeError("Cannot access environment engine - HARD FAIL")
|
406
|
+
|
407
|
+
# REQUIRE screen access to work
|
408
|
+
if not hasattr(env.engine, "emulator") or not env.engine.emulator:
|
409
|
+
raise RuntimeError("Emulator not available - HARD FAIL")
|
410
|
+
|
411
|
+
if not hasattr(env.engine.emulator, "screen"):
|
412
|
+
raise RuntimeError("Emulator screen not available - HARD FAIL")
|
413
|
+
|
414
|
+
# Use PyBoy's documented screen.ndarray property - shape (144, 160, 4) RGBA
|
415
|
+
screen_buffer = (
|
416
|
+
env.engine.emulator.screen.ndarray.copy()
|
417
|
+
) # Copy to avoid reference issues
|
418
|
+
|
419
|
+
if screen_buffer is None:
|
420
|
+
raise RuntimeError("Screen ndarray is None - HARD FAIL")
|
421
|
+
|
422
|
+
# Store screen buffer for agent to access
|
423
|
+
self.screen_buffer = screen_buffer
|
424
|
+
print(f"[DEBUG] Successfully extracted screen buffer with shape: {screen_buffer.shape}")
|
425
|
+
|
426
|
+
# Perform detailed screen analysis
|
427
|
+
screen_analysis = analyze_screen_buffer(screen_buffer)
|
428
|
+
print(
|
429
|
+
f"[DEBUG] Screen analysis completed - type: {screen_analysis.get('screen_type', 'UNKNOWN')}"
|
430
|
+
)
|
431
|
+
|
432
|
+
# Get additional game state context - REQUIRE this to work
|
433
|
+
current_state = env.engine._extract_current_state()
|
434
|
+
if not current_state:
|
435
|
+
raise RuntimeError("Failed to extract game state - HARD FAIL")
|
436
|
+
|
437
|
+
# Use the new structured state information from the public state
|
438
|
+
additional_context += f"\nWarp Flag: {pub.system.warp_flag}"
|
439
|
+
additional_context += f"\nBattle Outcome: {pub.system.battle_outcome}"
|
440
|
+
additional_context += f"\nInventory Count: {len(pub.inventory)}"
|
441
|
+
|
442
|
+
except Exception as e:
|
443
|
+
# HARD FAIL on any screen/context extraction errors
|
444
|
+
raise RuntimeError(f"Screen/context extraction HARD FAIL: {e}")
|
445
|
+
|
446
|
+
# Format the base observation with screen analysis
|
447
|
+
if self.mode == "state_and_ascii":
|
448
|
+
# Include ASCII analysis but no screen buffer in observation
|
449
|
+
formatted_obs = format_obs_for_llm_from_states(pub, priv, screen_analysis, self.mode)
|
450
|
+
else:
|
451
|
+
# Include screen analysis for screen mode
|
452
|
+
formatted_obs = format_obs_for_llm_from_states(pub, priv, screen_analysis, self.mode)
|
453
|
+
|
454
|
+
# Add context info
|
455
|
+
enhanced_obs = formatted_obs.replace(
|
456
|
+
"\n=== END GAME STATE ===", f"{additional_context}\n=== END GAME STATE ==="
|
457
|
+
)
|
458
|
+
|
459
|
+
# Add reward information to the observation
|
460
|
+
if additional_reward > 0 or self.reward_manager.total_reward_earned > 0:
|
461
|
+
reward_info = "\n\n=== REWARD PROGRESS ===\n"
|
462
|
+
if additional_reward > 0:
|
463
|
+
reward_info += f"Step Reward: +{additional_reward:.1f}\n"
|
464
|
+
reward_info += f"Total Rewards Earned: {self.reward_manager.total_reward_earned:.1f}\n"
|
465
|
+
|
466
|
+
# Show recent reward achievements (last 3)
|
467
|
+
if self.reward_manager.reward_history:
|
468
|
+
reward_info += "Recent Achievements:\n"
|
469
|
+
for achievement in self.reward_manager.reward_history[-3:]:
|
470
|
+
for component in achievement["components"]:
|
471
|
+
reward_info += f"• {component['component']}: +{component['reward']:.1f}\n"
|
472
|
+
|
473
|
+
enhanced_obs = enhanced_obs.replace(
|
474
|
+
"\n=== END GAME STATE ===", f"{reward_info}=== END GAME STATE ==="
|
475
|
+
)
|
476
|
+
|
477
|
+
self._hist_obs.append(enhanced_obs)
|
478
|
+
self._hist_pub_state.append(pub)
|
479
|
+
self._hist_priv_state.append(priv)
|
480
|
+
|
481
|
+
observation_dict = {
|
482
|
+
"public": pub,
|
483
|
+
"private": priv,
|
484
|
+
"formatted_obs": enhanced_obs,
|
485
|
+
"history_formatted_obs": list(self._hist_obs),
|
486
|
+
"history_public_states": list(self._hist_pub_state),
|
487
|
+
"history_private_states": list(self._hist_priv_state),
|
488
|
+
}
|
489
|
+
|
490
|
+
# Only include screen buffer for screen mode
|
491
|
+
if self.mode == "state_and_screen":
|
492
|
+
observation_dict["screen_buffer"] = self.screen_buffer
|
493
|
+
|
494
|
+
return observation_dict # type: ignore[return-value]
|
495
|
+
|
496
|
+
|
497
|
+
# --- Pydantic Models for Tool Arguments ---
|
498
|
+
class PokemonRedInteractArgs(BaseModel):
    # Arguments for the `pokemon_red_interact` tool: 1-5 button names plus the
    # model's rationale. Documented with comments rather than a class docstring,
    # because pydantic would copy a docstring into the tool's JSON schema.
    buttons: List[str] = Field(
        description=(
            "A sequence of 1-5 buttons to press in Pokemon Red "
            "(e.g., ['A'], ['UP', 'RIGHT'], ['START', 'DOWN', 'A']). "
            "Each button should be one of: A, B, UP, DOWN, LEFT, RIGHT, START, SELECT."
        )
    )
    reasoning: str = Field(
        description=(
            "A brief explanation of why this sequence of buttons was chosen "
            "and what you expect to accomplish."
        )
    )

    @validator("buttons")
    def validate_buttons(cls, v):
        # Reject empty or over-long sequences, then upper-case each name while
        # checking it against the allowed set (first bad name raises).
        allowed = {"A", "B", "UP", "DOWN", "LEFT", "RIGHT", "START", "SELECT"}
        if not v:
            raise ValueError("Must provide at least one button")
        if len(v) > 5:
            raise ValueError("Cannot provide more than 5 buttons in sequence")

        normalized = []
        for raw_name in v:
            upper_name = raw_name.upper()
            if upper_name not in allowed:
                raise ValueError(f"Invalid button: {raw_name}. Valid buttons: {allowed}")
            normalized.append(upper_name)
        return normalized
|
517
|
+
|
518
|
+
|
519
|
+
class TerminateArgs(BaseModel):
    # Arguments for the `terminate` tool: a single free-text reason.
    # (No class docstring: pydantic would surface it in the tool schema.)
    reason: str = Field(
        description=(
            "Reason for termination "
            "(e.g., 'all tasks complete', 'stuck', 'max_steps_reached')."
        )
    )
|
523
|
+
|
524
|
+
|
525
|
+
# --- Environment tool call wrapper ---
|
526
|
+
class PressButtonCall(EnvToolCall):
    """Helper class for creating button press calls"""

    def __init__(self, button: str, frames: int = 1):
        # Map this convenience constructor onto the generic tool-call shape:
        # the `press_button` tool with the button name and frame count as args.
        call_args = {"button": button, "frames": frames}
        super().__init__(tool="press_button", args=call_args)
|
531
|
+
|
532
|
+
|
533
|
+
# --- ReAct agent for Pokemon Red ---
|
534
|
+
class ReActAgent:
    """ReAct-style agent that plays Pokemon Red through LLM tool calls.

    Each :meth:`decide` call:
    * records the observation in ``self.history``,
    * asks the LLM to call ``pokemon_red_interact`` or ``terminate``,
    * returns the button list to execute (``["TERMINATE"]`` to stop).

    Malformed or invalid tool calls fall back to ``["A"]``. A response with no
    tool calls at all is a hard failure (``AssertionError``).
    """

    def __init__(self, llm, max_turns: int = 50):
        self.llm, self.max_turns = llm, max_turns
        self.history: List[Dict[str, Any]] = []
        self.system_name: str = "pokemon-red-react"
        self.system_id: Any = get_system_id(self.system_name)
        self.system_instance_id: str = str(uuid.uuid4())
        self.last_obs_dict: Optional[Dict[str, Any]] = None
        self.current_badges: int = 0

        # Valid button inputs for Pokemon Red
        self.valid_buttons = [
            "A",
            "B",
            "UP",
            "DOWN",
            "LEFT",
            "RIGHT",
            "START",
            "SELECT",
        ]

        # Tool schemas offered to the LLM on every call
        self.tools = [
            BaseTool(
                name="pokemon_red_interact",
                description="Interacts with the Pokemon Red game by pressing a button.",
                arguments=PokemonRedInteractArgs,
            ),
            BaseTool(
                name="terminate",
                description="Terminates the agent's execution if the task is considered complete or no useful progress can be made.",
                arguments=TerminateArgs,
            ),
        ]

    def _format_history_for_prompt(self) -> str:
        """Render ``self.history`` as an OBSERVATION / THOUGHT / TOOL_RESPONSE transcript."""
        prompt_history = []
        for entry in self.history:
            if entry["type"] == "obs":
                prompt_history.append(f"OBSERVATION:\n{entry['content']}")
            elif entry["type"] == "tool_call":
                args_str = json.dumps(entry["tool_arguments"])
                prompt_history.append(
                    f"THOUGHT:\nI will call the tool `{entry['tool_name']}` with arguments: {args_str}\nACTION: (Tool call executed)"
                )
            elif entry["type"] == "tool_response":
                prompt_history.append(
                    "TOOL_RESPONSE:\n(Button pressed, new observation will follow if not terminal)"
                )
        return "\n".join(prompt_history)

    def _get_recent_reasoning_traces(self, k: int = 5) -> str:
        """Get the reasoning from the last k tool calls to help agent avoid repeating mistakes.

        Only calls whose arguments are a dict containing ``reasoning`` are
        listed. A repetition warning is appended when the last three such calls
        repeat a multi-button sequence or start with the same button.
        Returns ``""`` when there is nothing to show.
        """
        tool_calls = [entry for entry in self.history if entry["type"] == "tool_call"]
        recent_calls = tool_calls[-k:]
        # Step numbers are positional (1-based among all tool calls).
        # history.index() matched by equality, so identical repeated calls all
        # got the step number of their first occurrence (and cost O(n^2)).
        first_step = len(tool_calls) - len(recent_calls) + 1

        recent_reasoning: List[str] = []
        recent_sequences: List[Any] = []
        for step_num, tool_call in enumerate(recent_calls, start=first_step):
            args = tool_call.get("tool_arguments")
            # Arguments may be non-dict (e.g. JSON null); skip rather than crash.
            if not isinstance(args, dict) or "reasoning" not in args:
                continue
            reasoning = args["reasoning"]
            buttons = args.get("buttons", ["unknown"])
            recent_reasoning.append(f"Step {step_num}: Pressed {buttons} - Reasoning: {reasoning}")
            recent_sequences.append(buttons)

        if not recent_reasoning:
            return ""

        header = "RECENT REASONING HISTORY:\n" + "\n".join(recent_reasoning)
        warning = self._repetition_warning(recent_sequences[-3:])
        if warning:
            return header + warning + "\n"
        return header + "\n\n"

    @staticmethod
    def _repetition_warning(last_sequences: List[Any]) -> str:
        """Return a warning when the last 3 button sequences repeat, else ``""``.

        An identical multi-button sequence gets the sequence warning. Otherwise,
        a shared first button gets the single-button warning. (Previously the
        sequence warning was unreachable because the first-button check always
        matched first.)
        """
        usable = [
            seq
            for seq in last_sequences
            if isinstance(seq, list) and seq and all(isinstance(b, str) for b in seq)
        ]
        if len(usable) < 3:
            return ""

        if len(usable[0]) > 1 and all(seq == usable[0] for seq in usable):
            return f"\n⚠️ WARNING: You've used the same button sequence {usable[0]} {len(usable)} times in a row! This sequence may not be working. Try a completely different approach like 'B' to cancel or different movement directions.\n"

        if len({seq[0] for seq in usable}) == 1:
            return f"\n⚠️ WARNING: You've pressed '{usable[0][0]}' button {len(usable)} times in a row! This button may not be working for the current situation. Try a different approach like pressing 'B' to cancel, or movement buttons to navigate away.\n"

        return ""

    def _encode_screen_image(self, current_raw_obs, current_step_count: int) -> List[str]:
        """Return ``[base64_png]`` for the observation's screen buffer, or ``[]``.

        Also saves a debug PNG next to this file. Failures are logged and yield
        ``[]``, so the text-only observation can still be used.
        """
        try:
            screen_buffer = current_raw_obs.get("screen_buffer") if current_raw_obs else None
            if screen_buffer is None:
                print("[AGENT_DEBUG] No screen buffer available in observation")
                return []
            print(f"[AGENT_DEBUG] Got screen buffer with shape: {screen_buffer.shape}")

            import base64
            import io

            import numpy as np
            from PIL import Image

            # Ensure the array is in the right format (0-255 uint8)
            if screen_buffer.dtype != np.uint8:
                if screen_buffer.max() <= 1.0:
                    screen_array = (screen_buffer * 255).astype(np.uint8)
                else:
                    screen_array = screen_buffer.astype(np.uint8)
            else:
                screen_array = screen_buffer

            # PyBoy yields (144, 160, 4) RGBA; plain RGB (..., 3) is accepted too.
            if screen_array.ndim != 3 or screen_array.shape[2] not in (3, 4):
                raise ValueError(f"Unsupported screen array shape: {screen_array.shape}")
            # A uint8 HxWx3 array is inferred as RGB. The `mode=` argument is deprecated in Pillow.
            image = Image.fromarray(screen_array[:, :, :3])

            # DEBUG: Save the image to debug directory
            debug_dir = Path(__file__).parent / "debug"
            debug_dir.mkdir(exist_ok=True)
            debug_path = (
                debug_dir / f"step_{current_step_count:04d}_agent_{self.system_instance_id[-8:]}.png"
            )
            image.save(debug_path)
            print(f"[DEBUG] Saved screen image to: {debug_path}")

            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
            print("[AGENT_DEBUG] Successfully converted screen to base64 image")
            return [base64_image]
        except Exception as e:
            print(f"[AGENT_DEBUG] Failed to extract screen buffer: {e}")
            # Continue without screen - the text observation should still work
            return []

    def _build_messages(self, obs_str: str, mode: str, current_step_count: int) -> tuple[str, str]:
        """Return ``(system_message, user_prompt)`` for this observation and mode."""
        ascii_mode = mode == "state_and_ascii"

        if ascii_mode:
            lead_in = (
                "Based on the game state text and ASCII representation above, "
                "what is your reasoning and which tool (`pokemon_red_interact` or `terminate`) should you call next? "
                "The ASCII representation shows the visual layout of the screen. "
            )
            screen_line = "• Visual Screen Analysis: ASCII representation and entity detection\n\n"
        else:  # state_and_screen
            lead_in = (
                "Based on the game state text above AND the game screen image (if provided), "
                "what is your reasoning and which tool (`pokemon_red_interact` or `terminate`) should you call next? "
                "Look at both the text information and the visual screen to understand what's happening in the game. "
            )
            screen_line = (
                "• Visual Screen Analysis: ASCII representation and actual screen images\n\n"
            )

        prompt = (
            f"{self._get_recent_reasoning_traces(k=5)}"
            f"CURRENT OBSERVATION:\n{obs_str}\n\n"
            + lead_in
            + "Look at your recent reasoning history to avoid repeating the same ineffective actions. "
            "Focus on making progress: collect badges, heal when HP is low, explore new areas, and interact with the world.\n"
            f"[Turn: {current_step_count}]"
        )

        system_message = (
            "You are an agent playing Pokemon Red. You receive structured game state information "
            "and can execute button sequences to interact with the game. "
            "Your goal is to progress through the game by collecting badges, training Pokemon, and exploring.\n\n"
            "GAME STATE INFORMATION:\n"
            "You receive detailed information about:\n"
            "• World Location: Current map ID and position coordinates\n"
            "• Player Progress: Badge count and money\n"
            "• Pokemon Party: Each Pokemon's species, level, HP, and XP\n"
            "• Inventory: Items with quantities\n"
            "• Game System State: Raw system flags and states\n"
            + screen_line
            + "AVAILABLE ACTIONS:\n"
            "You can execute sequences of 1-5 buttons. Use as many button presses as are appropriate - sometimes 1 or 2, occasionally 3-5:\n"
            f"• Available buttons: {', '.join(self.valid_buttons)}\n"
            "• Examples: ['A'], ['UP', 'RIGHT'], ['START', 'DOWN', 'A']\n\n"
            "IMPORTANT GUIDANCE:\n"
            "• If 'Text Box Active: True' and A button isn't working, try B to cancel or navigate away\n"
            "• If you're repeating the same button many times without progress, try a different approach\n"
            "• When stuck, try movement buttons (UP, DOWN, LEFT, RIGHT) to explore or navigate menus\n"
            "• B button often cancels menus or text boxes when A doesn't work\n"
            "• Look at your recent reasoning history to avoid ineffective repeated actions\n"
            "• Use shorter button sequences (1-3 buttons) rather than long sequences\n"
            "• If the same action doesn't work after 2-3 tries, try something completely different\n\n"
            "TOOLS AVAILABLE:\n"
            "• pokemon_red_interact: Execute button sequences\n"
            "• terminate: End the session\n\n"
            "Make decisions based on the game state information provided. "
            "Always provide reasoning that references the specific state information."
        )
        return system_message, prompt

    @staticmethod
    def _extract_tool_call(tool_call_data):
        """Return ``(name, raw_arguments)`` from an SDK object or a plain dict, or None if unrecognized."""
        function = getattr(tool_call_data, "function", None)
        if function is not None and hasattr(function, "name") and hasattr(function, "arguments"):
            return function.name, function.arguments
        if isinstance(tool_call_data, dict) and isinstance(tool_call_data.get("function"), dict):
            fn = tool_call_data["function"]
            return fn.get("name"), fn.get("arguments")
        return None

    def _handle_tool_calls(self, tool_calls) -> List[str]:
        """Turn the first tool call into buttons and record it in history.

        Exceptions propagate to :meth:`decide`, which falls back to ``["A"]``.
        """
        print(f"[AGENT_DEBUG] Found {len(tool_calls)} tool calls")

        extracted = self._extract_tool_call(tool_calls[0])
        if extracted is None:
            print("[AGENT_DEBUG] Unexpected tool_call structure, falling back to A")
            self.history.append({"type": "error", "content": "Unexpected tool_call structure."})
            return ["A"]
        tool_name, raw_args = extracted

        # Providers return arguments either as a JSON string or already parsed.
        if raw_args is None or isinstance(raw_args, str):
            tool_args_str = raw_args
        else:
            tool_args_str = json.dumps(raw_args)

        print(f"[AGENT_DEBUG] Tool name: {tool_name}, Args: {tool_args_str}")

        if not tool_args_str:
            print(f"[AGENT_DEBUG] Missing arguments for tool {tool_name}, falling back to A")
            self.history.append(
                {
                    "type": "error",
                    "content": f"Missing arguments for tool {tool_name}. Args string: '{tool_args_str}'",
                }
            )
            return ["A"]

        tool_arguments = json.loads(tool_args_str)
        self.history.append(
            {
                "type": "tool_call",
                "tool_name": tool_name,
                "tool_arguments": tool_arguments,
            }
        )

        if tool_name == "pokemon_red_interact":
            print("[AGENT_DEBUG] Processing pokemon_red_interact tool call")
            validated_args = PokemonRedInteractArgs(**tool_arguments)
            buttons = validated_args.buttons
            print(
                f"[AGENT_DEBUG] Buttons: {buttons}, Valid: {[button in self.valid_buttons for button in buttons]}"
            )

            # Defensive re-check; the pydantic validator should already reject these.
            invalid_buttons = [button for button in buttons if button not in self.valid_buttons]
            if invalid_buttons:
                print(f"[AGENT_DEBUG] Invalid buttons: {invalid_buttons}, falling back to A")
                self.history.append(
                    {
                        "type": "error",
                        "content": f"Invalid buttons: {invalid_buttons}. Falling back to A.",
                    }
                )
                return ["A"]
            print(f"[AGENT_DEBUG] Returning buttons: {buttons}")
            return buttons

        if tool_name == "terminate":
            print("[AGENT_DEBUG] Processing terminate tool call")
            # Allow termination if agent decides
            print("[AGENT_DEBUG] Agent decided to terminate, returning TERMINATE")
            return ["TERMINATE"]

        print(f"[AGENT_DEBUG] Unknown tool_name: {tool_name}, falling back to A")
        self.history.append({"type": "error", "content": f"Unknown tool_name: {tool_name}"})
        return ["A"]

    @trace_event_async(event_type="react_agent_decide")
    async def decide(
        self,
        obs_str: str,
        current_raw_obs: Dict[str, Any],
        mode: Literal["state_and_ascii", "state_and_screen"] = "state_and_screen",
    ) -> List[str]:
        """Choose the next button sequence for the current observation.

        Args:
            obs_str: Formatted text observation shown to the LLM.
            current_raw_obs: Observation dict. ``"public"`` (a
                ``PokemonRedPublicState``) and ``"screen_buffer"`` are read
                when present.
            mode: ``"state_and_screen"`` also sends the screen as a PNG image.

        Returns:
            Buttons to press, ``["TERMINATE"]`` if the agent chose to stop, or
            ``["A"]`` when the LLM's tool call could not be used.

        Raises:
            AssertionError: if the LLM response contains no tool calls.
        """
        print(f"[AGENT_DEBUG] Starting decide with obs: {obs_str[:100]}...")
        self.history.append({"type": "obs", "content": obs_str})
        self.last_obs_dict = current_raw_obs

        pub_state = current_raw_obs.get("public") if current_raw_obs else None
        has_pub_state = isinstance(pub_state, PokemonRedPublicState)
        if has_pub_state:
            self.current_badges = pub_state.badges

        print(f"[AGENT_DEBUG] History length: {len(self.history)}")

        # The step count is echoed in the prompt (this also busts prompt caches).
        current_step_count = pub_state.step_count if has_pub_state else 0

        screen_images_bytes: List[str] = []
        if mode == "state_and_screen":
            screen_images_bytes = self._encode_screen_image(current_raw_obs, current_step_count)

        system_message, prompt = self._build_messages(obs_str, mode, current_step_count)

        print("=" * 80)
        print("[AI_INPUT] SYSTEM MESSAGE:")
        print(system_message)
        print("-" * 40)
        print("[AI_INPUT] USER MESSAGE:")
        print(prompt)
        print("-" * 40)
        print("[AI_INPUT] TOOLS:")
        print(json.dumps([tool.to_openai_tool() for tool in self.tools], indent=2))
        print("-" * 40)
        print(f"[AI_INPUT] IMAGES: {len(screen_images_bytes)} image(s) provided")
        print("=" * 80)

        print(
            f"[AGENT_DEBUG] Calling LLM with prompt length: {len(prompt)}, images: {len(screen_images_bytes)}"
        )
        response_obj = await self.llm.respond_async(
            system_message=system_message,
            user_message=prompt,
            tools=self.tools,
            images_as_bytes=screen_images_bytes,
        )
        print("[AGENT_DEBUG] LLM response received")

        print("=" * 80)
        print("[AI_OUTPUT] RESPONSE OBJECT:")
        print(f"Response type: {type(response_obj)}")
        print(f"Response content: {response_obj}")
        if hasattr(response_obj, "tool_calls"):
            print(f"Tool calls: {response_obj.tool_calls}")
        if hasattr(response_obj, "content"):
            print(f"Content: {response_obj.content}")
        print("=" * 80)

        # Deliberate hard failure. This is an explicit raise rather than `assert`
        # so it still fires under `python -O`.
        if not response_obj.tool_calls:
            raise AssertionError("Response object didn't have tool call")

        try:
            return self._handle_tool_calls(response_obj.tool_calls)
        except Exception as e:
            error_content = (
                f"Error processing LLM response: {str(e)}. Response: {str(response_obj)[:500]}"
            )
            print(f"[AGENT_DEBUG] Exception in decide: {error_content}")
            self.history.append({"type": "error", "content": error_content})
            return ["A"]
|
945
|
+
|
946
|
+
|
947
|
+
# --- Test for a single agent run ---
|
948
|
+
@pytest.mark.asyncio
async def test_react_agent_pokemon_red(
    tmp_path: Path,
    mode: Literal["state_and_ascii", "state_and_screen"] = "state_and_screen",
):
    """Run one ReAct episode on Pokemon Red and require some progress.

    The test passes if the environment reports termination/truncation, or if
    the agent's last observed badge count is positive. ``tmp_path`` is accepted
    as a pytest fixture but is not used.
    """
    # Create a simple Pokemon Red task instance for testing
    task_metadata = TaskInstanceMetadata()
    inst = PokemonRedTaskInstance(
        id=uuid.uuid4(),
        impetus=Impetus(instructions="Start your Pokemon journey and collect badges."),
        intent=Intent(
            rubric={"goal": "Collect badges and progress"},
            gold_trajectories=None,
            gold_state_diff={},
        ),
        metadata=task_metadata,
        is_reproducible=True,
        initial_engine_snapshot=None,
    )

    hist_cb = PokemonRedHistoryObservationCallable(max_history=1, mode=mode)
    env = PokemonRedEnvironment(inst, custom_step_obs=hist_cb)

    llm = LM(model_name="gpt-4.1-nano", formatting_model_name="gpt-4.1-nano", temperature=0.0)
    agent = ReActAgent(llm, max_turns=30)

    async def run_episode():
        obs_payload = await env.initialize()

        if "error" in obs_payload:
            print(f"Error during env.initialize: {obs_payload['error']}")
            return False, 0

        episode_over = False
        for _ in range(agent.max_turns):
            # Decide on the LATEST observation. Previously the initial observation
            # was passed every turn, so the agent never saw its own progress.
            buttons = await agent.decide(obs_payload["formatted_obs"], obs_payload, mode)

            if "TERMINATE" in buttons:
                break

            # Execute button sequence one by one
            for i, button in enumerate(buttons):
                print(f"[DEBUG] Executing button {i + 1}/{len(buttons)}: {button}")
                obs_payload_next = await env.step([[PressButtonCall(button)]])

                if "error" in obs_payload_next:
                    raise RuntimeError(
                        f"Environment step error on button {i + 1}: {obs_payload_next['error']}"
                    )

                # Update observation after each button press
                obs_payload = obs_payload_next

                # Check if environment terminated after this button
                if obs_payload["private"].terminated or obs_payload["private"].truncated:
                    print(
                        f"[DEBUG] Environment terminated/truncated after button {i + 1}/{len(buttons)}"
                    )
                    episode_over = True
                    break

            if episode_over:
                # Stop the turn loop too; don't keep stepping an ended environment.
                break

        if "error" in obs_payload:
            return False, agent.current_badges

        final_private_state: PokemonRedPrivateState = obs_payload["private"]
        episode_successful = final_private_state.terminated or final_private_state.truncated
        return episode_successful, agent.current_badges

    episode_completed, badges_collected = await run_episode()

    dataset = Dataset(
        questions=[
            TrainingQuestion(
                id="pokemon_red_ep_test",
                intent="progress_in_game",
                criteria="completed_episode_or_collected_badges",
            )
        ],
        reward_signals=[
            RewardSignal(
                question_id="pokemon_red_ep_test",
                run_id=agent.system_instance_id,
                system_instance_id=agent.system_instance_id,
                reward=1 if episode_completed or badges_collected > 0 else 0,
                error_message="" if episode_completed else "Episode not completed as expected.",
                metadata={
                    "agent_history": agent.history,
                    "badges_collected": badges_collected,
                    "total_reward_earned": hist_cb.reward_manager.total_reward_earned,
                    "reward_history": hist_cb.reward_manager.reward_history,
                },
            )
        ],
    )
    # upload(dataset=dataset) # Optional: uncomment to upload trace

    assert episode_completed or badges_collected > 0, (
        "Agent failed to complete the episode or collect any badges in the test."
    )
|
1052
|
+
|
1053
|
+
|
1054
|
+
async def eval_react_pokemon_red(
    model_name: str = "gpt-4o-mini",
    max_turns: int = 20,
    mode: Literal["state_and_ascii", "state_and_screen"] = "state_and_screen",
) -> None:
    """
    Run ReAct agents on Pokemon Red instances of different difficulties,
    and print aggregated success rates and average badges collected.

    For every entry in the local ``configs`` list this generates task
    instances, serializes them to ``<parent-of-this-dir>/dataset/synthetic_mix.json``,
    runs one episode per instance concurrently via ``asyncio.gather``, and
    prints a summary table plus a per-component tally of the reward history.

    Args:
        model_name: Model used as both the reasoning and formatting model.
        max_turns: Maximum number of agent turns per episode.
        mode: Observation mode forwarded to the history callable and agent.

    Raises:
        RuntimeError: If the environment reports an error, the agent returns
            no buttons, or the stuck-detection heuristics declare a hard fail.
    """
    import hashlib
    import time

    from tabulate import tabulate

    current_model_name_for_eval = model_name

    # Throwaway agent built only to read the system name for report headers.
    _temp_llm_for_names = LM(
        model_name=current_model_name_for_eval,
        formatting_model_name=current_model_name_for_eval,
        temperature=0.0,
    )
    _temp_agent_for_names = ReActAgent(_temp_llm_for_names)
    actual_system_name = _temp_agent_for_names.system_name

    # ------------------------------------------------------------------ helpers
    async def run_episode_eval(
        inst: PokemonRedTaskInstance, agent_max_turns: int
    ) -> tuple[bool, int, float, list]:
        """Run a single agent/instance episode and return (success_status, badges_collected, total_rewards, reward_history)."""
        print(f"[DEBUG] Starting episode for instance {inst.id}")
        hist_cb = PokemonRedHistoryObservationCallable(max_history=1, mode=mode)
        env = PokemonRedEnvironment(inst, custom_step_obs=hist_cb)

        llm_for_episode = LM(
            model_name=current_model_name_for_eval,
            formatting_model_name=current_model_name_for_eval,
            temperature=0.0,
        )
        agent = ReActAgent(llm_for_episode, max_turns=agent_max_turns)
        print(f"[DEBUG] Created agent with max_turns={agent_max_turns}")

        print("[DEBUG] Initializing environment...")
        obs_payload = await env.initialize()
        print(
            f"[DEBUG] Environment initialized. Obs keys: {list(obs_payload.keys()) if isinstance(obs_payload, dict) else type(obs_payload)}"
        )
        if "error" in obs_payload:
            raise RuntimeError(f"Environment initialization failed: {obs_payload['error']}")

        current_formatted_obs = obs_payload["formatted_obs"]
        raw_obs_for_agent_decision = obs_payload
        print(f"[DEBUG] Initial formatted obs: {current_formatted_obs[:200]}...")

        # Stuck-detection heuristics. Based on investigation: menu_state=1 is the
        # normal overworld state, and B often does nothing there (no menu to
        # close), so tolerances are per button. Built once; loop-invariant.
        button_tolerance = {
            "B": 15,  # B often does nothing in overworld - very lenient
            "A": 10,  # A for interactions/dialogue - moderately lenient
            "START": 8,  # START for menu opening - moderate
            "SELECT": 8,  # SELECT for menu navigation - moderate
            "UP": 5,  # Movement buttons - less lenient
            "DOWN": 5,
            "LEFT": 5,
            "RIGHT": 5,
        }
        default_button_tolerance = 5  # for buttons not listed above
        min_screen_unchanged = 12  # Pokemon Red often has static screens
        min_turn_threshold = 10  # allow exploration time before failing
        b_button_hard_limit = 20  # B gets extra slack even when visually stuck

        # Per-episode state. It lives in locals rather than on the function
        # object so that episodes run concurrently by asyncio.gather (or one
        # after another) never share counters.
        last_position = None
        last_map_id = None
        stuck_count = 0
        same_button_count = 0
        last_button = None
        last_screen_hash = None
        same_screen_count = 0

        turn_count = 0
        for turn_idx in range(agent.max_turns):
            turn_count += 1
            print(f"[DEBUG] === Turn {turn_idx + 1}/{agent.max_turns} ===")
            print(f"[DEBUG] Agent deciding on obs: {current_formatted_obs[:100]}...")

            buttons = await agent.decide(current_formatted_obs, raw_obs_for_agent_decision, mode)
            print(f"[DEBUG] Agent decided buttons: {buttons}")
            if not buttons:
                raise RuntimeError("Agent returned an empty button sequence")

            # Track repeated presses of the same leading button. The engine
            # retries automatically and some game states legitimately need the
            # same button many times, so this only warns.
            if buttons[0] == last_button:
                same_button_count += 1
                if same_button_count >= 8:
                    print(
                        f"[WARNING] Agent pressed same button '{buttons[0]}' {same_button_count} times in a row"
                    )
            else:
                same_button_count = 1
            last_button = buttons[0]

            if "TERMINATE" in buttons:
                print(f"[DEBUG] Agent decided to terminate after {turn_count} turns")
                break

            print(f"[DEBUG] Stepping environment with buttons {buttons}")

            try:
                # Execute the button sequence one press at a time.
                for i, button in enumerate(buttons):
                    print(f"[DEBUG] Executing button {i + 1}/{len(buttons)}: {button}")
                    obs_payload_next = await env.step([[PressButtonCall(button)]])

                    if "error" in obs_payload_next:
                        raise RuntimeError(
                            f"Environment step error on button {i + 1}: {obs_payload_next['error']}"
                        )

                    # Update observation after each button press.
                    obs_payload = obs_payload_next

                    # Stop pressing once the environment has finished.
                    if obs_payload["private"].terminated or obs_payload["private"].truncated:
                        print(
                            f"[DEBUG] Environment terminated/truncated after button {i + 1}/{len(buttons)}"
                        )
                        break
            except RuntimeError as e:
                if "HARD FAIL" in str(e):
                    raise  # Re-raise hard failures immediately
                raise RuntimeError(f"Environment step failed: {e}") from e

            print(
                f"[DEBUG] Environment step completed. Obs keys: {list(obs_payload.keys()) if isinstance(obs_payload, dict) else type(obs_payload)}"
            )

            if "error" in obs_payload:
                raise RuntimeError(f"Environment step error: {obs_payload['error']}")

            # Detect visual stuckness by hashing the raw screen buffer, when present.
            screen_changed = True
            screen_buffer = obs_payload.get("screen_buffer")
            if screen_buffer is not None:
                current_screen_hash = hashlib.md5(screen_buffer.tobytes()).hexdigest()
                if current_screen_hash == last_screen_hash:
                    same_screen_count += 1
                    screen_changed = False
                else:
                    same_screen_count = 0
                last_screen_hash = current_screen_hash
                print(
                    f"[DEBUG] Screen hash: {current_screen_hash[:8]}..., Same count: {same_screen_count}, Changed: {screen_changed}"
                )

            max_same_button = button_tolerance.get(buttons[0], default_button_tolerance)

            # Only fail if BOTH the screen has been static for many turns AND
            # the agent keeps repeating an action beyond its tolerance.
            lenient_b_press = False
            if (
                same_screen_count >= min_screen_unchanged
                and turn_idx > min_turn_threshold
                and same_button_count >= max_same_button
            ):
                if buttons[0] == "B" and same_button_count < b_button_hard_limit:
                    # B in the overworld is often ineffective but not wrong.
                    # Skip the stuck checks for this turn, but still fall through
                    # so the observation is refreshed and termination is checked.
                    print(
                        f"[DEBUG] B button often ineffective in overworld - allowing more attempts ({same_button_count}/20)"
                    )
                    lenient_b_press = True
                else:
                    print(
                        f"[WARNING] Agent appears stuck - screen unchanged for {same_screen_count} turns with repeated button '{buttons[0]}' {same_button_count} times"
                    )
                    print(
                        f"[WARNING] Button tolerance for '{buttons[0]}': {max_same_button}, screen unchanged threshold: {min_screen_unchanged}"
                    )
                    raise RuntimeError(
                        f"Agent stuck - screen unchanged for {same_screen_count} turns with repeated button '{buttons[0]}' ({same_button_count} times, tolerance: {max_same_button}) - HARD FAIL"
                    )

            # Legacy position-based detection: a lenient fallback that only
            # counts turns where the screen is also unchanged.
            current_pub = obs_payload["public"]
            current_position = (current_pub.player_x, current_pub.player_y)
            current_map_id = current_pub.map_id

            if not lenient_b_press:
                if (
                    last_position == current_position
                    and last_map_id == current_map_id
                    and not screen_changed
                    and turn_idx > 8
                ):  # Much more lenient - allow many turns for dialogue
                    stuck_count += 1
                    if stuck_count >= 8:  # Require many more turns of true stuck state
                        raise RuntimeError(
                            f"Agent truly stuck - no position or screen changes for {stuck_count} turns. Position: {current_position}, Map: {current_map_id} - HARD FAIL"
                        )
                else:
                    stuck_count = 0

            last_position = current_position
            last_map_id = current_map_id

            current_formatted_obs = obs_payload["formatted_obs"]
            raw_obs_for_agent_decision = obs_payload

            agent.history.append(
                {
                    "type": "tool_response",
                    "content": f"Button sequence executed: {buttons}",
                }
            )

            print(f"[DEBUG] New formatted obs: {current_formatted_obs[:100]}...")

            if obs_payload["private"].terminated or obs_payload["private"].truncated:
                print(f"[DEBUG] Environment terminated/truncated after {turn_count} turns")
                print(
                    f"[DEBUG] Terminated: {obs_payload['private'].terminated}, Truncated: {obs_payload['private'].truncated}"
                )
                break

        print(f"[DEBUG] Episode completed after {turn_count} turns")
        final_private_state: PokemonRedPrivateState = obs_payload["private"]
        run_successful = final_private_state.terminated or final_private_state.truncated
        badges_collected = agent.current_badges
        total_rewards = hist_cb.reward_manager.total_reward_earned
        print(
            f"[DEBUG] Episode result - successful: {run_successful}, badges: {badges_collected}, rewards: {total_rewards:.1f}"
        )
        print(
            f"[DEBUG] Final private state - terminated: {final_private_state.terminated}, truncated: {final_private_state.truncated}"
        )
        print(f"[DEBUG] Total reward: {final_private_state.total_reward}")
        return (
            run_successful,
            badges_collected,
            total_rewards,
            hist_cb.reward_manager.reward_history,
        )

    # ---------------------------------------------------------------- instance factory
    async def make_pokemon_red_instances(
        difficulty: str, n_instances: int = 3, start_seed: int = 0
    ) -> List[PokemonRedTaskInstance]:
        """Build ``n_instances`` Pokemon Red task instances for ``difficulty``.

        NOTE(review): ``start_seed`` is accepted for interface stability but is
        not applied. Instances are built without a seed and
        ``initial_engine_snapshot=None``.
        """
        instances = []
        for _ in range(n_instances):
            metadata = TaskInstanceMetadata()
            instance = PokemonRedTaskInstance(
                id=uuid.uuid4(),
                impetus=Impetus(
                    instructions=f"Play Pokemon Red on {difficulty} difficulty and collect badges."
                ),
                intent=Intent(rubric={}, gold_trajectories=None, gold_state_diff={}),
                metadata=metadata,
                is_reproducible=True,
                initial_engine_snapshot=None,
            )
            instances.append(instance)
        return instances

    # ---------------------------------------------------------------- evaluation
    # (difficulty_label, num_agents/instances, max_turns_per_episode)
    configs = [
        ("easy", 1, max_turns),
    ]
    table_rows = []
    base_seed_for_difficulty = {"easy": 1000, "hard": 2000}

    print("Starting Pokemon Red ReAct Agent Evaluation...")
    print(f"Model: {current_model_name_for_eval}, System: {actual_system_name}")

    all_generated_task_data = []
    all_reward_achievements = {}  # Track all rewards across all runs

    print("\nGenerating task instances...")
    all_tasks_for_eval: Dict[str, List[PokemonRedTaskInstance]] = {}
    for label, num_agents, _ in configs:
        insts = await make_pokemon_red_instances(
            label, n_instances=num_agents, start_seed=base_seed_for_difficulty[label]
        )
        all_tasks_for_eval[label] = insts
        for inst in insts:
            instance_dict = await inst.serialize()
            all_generated_task_data.append(instance_dict)
        print(f"Generated {len(insts)} instances for {label} difficulty.")

    # Save all generated task data to a single JSON file.
    dataset_dir = Path(__file__).parent.parent / "dataset"
    dataset_dir.mkdir(parents=True, exist_ok=True)
    synthetic_mix_path = dataset_dir / "synthetic_mix.json"
    with open(synthetic_mix_path, "w") as f:
        json.dump(all_generated_task_data, f, indent=2)
    print(
        f"Saved all {len(all_generated_task_data)} generated task instances to {synthetic_mix_path}"
    )

    # Now, run the evaluations using the generated tasks.
    for label, num_agents, max_episode_turns in configs:
        print(
            f"\nRunning {num_agents} agents on {label} difficulty tasks (max_turns: {max_episode_turns})..."
        )
        current_difficulty_instances = all_tasks_for_eval[label]
        print(f"[DEBUG] About to run {len(current_difficulty_instances)} instances")

        start_time = time.time()
        print(
            f"[DEBUG] Starting asyncio.gather for {len(current_difficulty_instances)} episodes at {start_time}"
        )
        results = await asyncio.gather(
            *(run_episode_eval(inst, max_episode_turns) for inst in current_difficulty_instances)
        )
        end_time = time.time()
        print(f"[DEBUG] Completed asyncio.gather in {end_time - start_time:.2f} seconds")
        print(f"[DEBUG] Results: {results}")

        num_successful_runs = sum(1 for r_success, _, _, _ in results if r_success)
        total_badges = sum(r_badges for _, r_badges, _, _ in results)
        total_rewards = sum(r_rewards for _, _, r_rewards, _ in results)
        avg_badges = total_badges / len(results) if results else 0.0
        avg_rewards = total_rewards / len(results) if results else 0.0

        # Tally how often each reward component fired across this difficulty's runs.
        for _, _, _, reward_history in results:
            for achievement in reward_history:
                for component in achievement["components"]:
                    component_name = component["component"]
                    all_reward_achievements[component_name] = (
                        all_reward_achievements.get(component_name, 0) + 1
                    )

        table_rows.append(
            [
                label,
                f"{num_successful_runs}/{len(current_difficulty_instances)}",
                f"{avg_badges:.2f}",
                f"{avg_rewards:.1f}",
            ]
        )
        print(
            f"Completed {label}: {num_successful_runs}/{len(current_difficulty_instances)} successful, Avg. Badges: {avg_badges:.2f}, Avg. Rewards: {avg_rewards:.1f}"
        )

    print("\n--- Evaluation Summary ---")
    print(f"Model: {current_model_name_for_eval}, System: {actual_system_name}")
    print(
        tabulate(
            table_rows,
            headers=[
                "Difficulty",
                "Successful Runs",
                "Avg Badges Collected",
                "Avg Rewards Earned",
            ],
            tablefmt="github",
        )
    )

    # Display reward achievements summary, most frequent first.
    if all_reward_achievements:
        print("\n--- Reward Achievements Summary ---")
        reward_summary_rows = [
            [reward_name, count]
            for reward_name, count in sorted(
                all_reward_achievements.items(), key=lambda x: x[1], reverse=True
            )
        ]

        print(
            tabulate(
                reward_summary_rows,
                headers=["Reward Component", "Times Achieved"],
                tablefmt="github",
            )
        )
        print(f"\nTotal Unique Rewards Achieved: {len(all_reward_achievements)}")
        print(f"Total Reward Instances: {sum(all_reward_achievements.values())}")
    else:
        print("\n--- No Rewards Achieved ---")
|
1459
|
+
|
1460
|
+
|
1461
|
+
if __name__ == "__main__":
    # Alternative entry point: run the single-episode test instead.
    #
    #     import tempfile
    #     with tempfile.TemporaryDirectory() as tmpdir:
    #         asyncio.run(test_react_agent_pokemon_red(Path(tmpdir)))

    # Default entry point: run the aggregated evaluation.
    eval_coro = eval_react_pokemon_red(
        model_name="gpt-4.1-mini",
        max_turns=10,
        mode="state_and_screen",
    )
    asyncio.run(eval_coro)
|