synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.1.9.dist-info/METADATA +0 -37
- synth_ai-0.1.9.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,235 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import List, Optional, Any, Dict, Union
|
3
|
+
from pydantic import BaseModel, Field
|
4
|
+
|
5
|
+
# NOTE(review): no logging-config import follows here; JAX debug-log suppression is presumably applied by an imported module (e.g. .engine) — confirm.
|
6
|
+
|
7
|
+
from .engine import (
|
8
|
+
PokemonRedEngine,
|
9
|
+
PokemonRedPrivateState,
|
10
|
+
PokemonRedPublicState,
|
11
|
+
PokemonRedEngineSnapshot,
|
12
|
+
)
|
13
|
+
from .taskset import PokemonRedTaskInstance, INSTANCE as DEFAULT_TASK_INSTANCE
|
14
|
+
from synth_ai.environments.environment.shared_engine import (
|
15
|
+
GetObservationCallable,
|
16
|
+
InternalObservation,
|
17
|
+
)
|
18
|
+
from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
|
19
|
+
from synth_ai.environments.stateful.core import StatefulEnvironment
|
20
|
+
from synth_ai.environments.environment.tools import (
|
21
|
+
AbstractTool,
|
22
|
+
EnvToolCall,
|
23
|
+
ToolResult,
|
24
|
+
TOOL_REGISTRY,
|
25
|
+
register_tool,
|
26
|
+
)
|
27
|
+
|
28
|
+
|
29
|
+
# Tool input schemas
|
30
|
+
class PressButtonInput(BaseModel):
    # Arguments accepted by the ``press_button`` tool.
    # NOTE(review): ``button`` is free-form here; rejecting names outside the
    # documented set is presumably the engine's job — confirm in engine.py.
    button: str = Field(
        ..., description="Game Boy button: A, B, UP, DOWN, LEFT, RIGHT, START, SELECT"
    )
    # A button has to be held for at least one frame; 0 or negative counts are
    # rejected at validation time instead of being passed on to the engine.
    frames: int = Field(1, ge=1, description="Number of frames to hold the button")
|
35
|
+
|
36
|
+
|
37
|
+
# Tool definitions
|
38
|
+
class PressButtonTool(AbstractTool):
    """Tool that presses a Game Boy button on a ``PokemonRedEngine``.

    Call arguments are validated with ``PressButtonInput`` and forwarded to the
    engine's ``_step_engine`` as ``{"button": ..., "frames": ...}``.
    """

    name = "press_button"
    description = "Press a Game Boy button for the specified number of frames"
    call_schema = PressButtonInput
    result_schema = ToolResult

    def __init__(self, engine: PokemonRedEngine):
        self.engine = engine

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        """Execute one button press.

        Returns:
            ``ToolResult(ok=True)`` whose payload holds the ``"public"`` and
            ``"private"`` states returned by the engine step. On any exception
            from validation or the engine step, returns ``ToolResult(ok=False)``
            with the error message and, when it can be built, the current
            ``"public"`` state for context. Exceptions are not propagated.
        """
        try:
            validated_args = self.call_schema(**call.args)
            action = {"button": validated_args.button, "frames": validated_args.frames}
            priv_state, pub_state = await self.engine._step_engine(action)
            return ToolResult(
                ok=True,
                payload={
                    "public": pub_state,
                    "private": priv_state,
                },
            )
        except Exception as e:
            return ToolResult(
                ok=False,
                error=str(e),
                payload=self._error_payload(),
            )

    def _error_payload(self) -> dict:
        """Build the payload for a failed call; best-effort, never raises."""
        try:
            # Get current state for error context
            _, pub_state = self.engine._create_states(reward=0.0)
        except Exception:
            # If the engine cannot even report its state, return an empty
            # payload rather than masking the original error with this one.
            return {}
        return {"public": pub_state}
|
67
|
+
|
68
|
+
|
69
|
+
# Observation callable for Pokemon Red
|
70
|
+
class PokemonRedObservationCallable(GetObservationCallable):
    """Default observation builder for the Pokemon Red environment."""

    async def get_observation(
        self, pub: PokemonRedPublicState, priv: PokemonRedPrivateState
    ) -> InternalObservation:
        """Flatten public and private state into an observation dict.

        Badge count, position and HP are rendered with the ``state_extraction``
        helpers; the raw badge bitfield and the remaining fields are copied
        unchanged. An ``"error"`` entry is added only when ``pub.error_info`` is
        truthy.
        """
        from .engine_helpers.state_extraction import (
            format_hp_status,
            format_position,
            get_badge_count,
        )

        badges = get_badge_count(pub.badges)
        where = format_position(pub.player_x, pub.player_y, pub.map_id)
        health = format_hp_status(pub.party_hp_current, pub.party_hp_max)

        observation = dict(
            position=where,
            badges_earned=badges,
            badges_bitfield=pub.badges,
            hp_status=health,
            party_level=pub.party_level,
            party_xp=pub.party_xp,
            in_battle=pub.in_battle,
            step_count=pub.step_count,
            reward_last_step=priv.reward_last_step,
            total_reward=priv.total_reward,
            terminated=priv.terminated,
        )

        if pub.error_info:
            observation["error"] = pub.error_info

        return observation
|
103
|
+
|
104
|
+
|
105
|
+
class PokemonRedEnvironment(StatefulEnvironment, ReproducibleEnvironment[PokemonRedEngine]):
    """Pokemon Red stateful game environment for AI agents.

    The only tool accepted by ``step`` is ``press_button``. Every step is
    executed by a ``PressButtonTool`` bound to this environment's engine.
    """

    def __init__(
        self,
        task_instance: Optional[PokemonRedTaskInstance] = None,
        custom_step_obs: Optional[GetObservationCallable] = None,
        custom_ckpt_obs: Optional[GetObservationCallable] = None,
    ):
        """Create the environment and its engine.

        Args:
            task_instance: Task to play; the module's default task instance is
                used when omitted (or falsy).
            custom_step_obs: Observation builder used by initialize/step/terminate;
                defaults to ``PokemonRedObservationCallable``.
            custom_ckpt_obs: Observation builder used by ``checkpoint``;
                defaults to ``PokemonRedObservationCallable``.
        """
        self.name = "PokemonRed"
        self.task_instance = task_instance or DEFAULT_TASK_INSTANCE
        self.custom_step_observation_callable = custom_step_obs or PokemonRedObservationCallable()
        self.custom_checkpoint_observation_callable = (
            custom_ckpt_obs or PokemonRedObservationCallable()
        )
        self.engine = PokemonRedEngine(self.task_instance)

        # Register tools. The global registry is only populated once per process,
        # so it keeps the tool of the first environment created; ``step`` always
        # uses ``self._press_button_tool`` and is unaffected by that.
        self._press_button_tool = PressButtonTool(self.engine)
        if self._press_button_tool.name not in TOOL_REGISTRY:
            register_tool(self._press_button_tool)

    async def initialize(self) -> InternalObservation:
        """Reset the engine and return the initial observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._to_observation(priv, pub, self.custom_step_observation_callable)

    async def terminate(self) -> InternalObservation:
        """Return a terminal observation built from the current engine state.

        The states are created with ``terminated=True`` and the observation is
        extended with ``terminated`` and ``message`` entries.
        """
        priv, pub = self.engine._create_states(reward=0.0, terminated=True)
        obs_dict = {
            "terminated": True,
            "message": "Pokemon Red environment terminated.",
        }
        return await self._to_observation(
            priv, pub, self.custom_step_observation_callable, extra_obs=obs_dict
        )

    def validate_tool_calls(
        self, tool_calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]]
    ) -> EnvToolCall:
        """Normalize ``tool_calls`` to a single ``press_button`` call.

        Accepts a bare call, a list, or a list of lists; only the first call is
        kept and the rest are ignored.

        Raises:
            ValueError: on an empty (outer or inner) list, or a tool other than
                ``press_button``.
            TypeError: when the input or the selected element is not an
                ``EnvToolCall``.
        """
        if isinstance(tool_calls, list):
            if not tool_calls:
                raise ValueError("Received empty list of tool calls.")
            if isinstance(tool_calls[0], list):
                if not tool_calls[0]:
                    raise ValueError("Received empty inner list of tool calls.")
                agent_call = tool_calls[0][0]
            else:
                agent_call = tool_calls[0]
        elif isinstance(tool_calls, EnvToolCall):
            agent_call = tool_calls
        else:
            raise TypeError(f"Unexpected type for tool_calls: {type(tool_calls)}")

        if not isinstance(agent_call, EnvToolCall):
            raise TypeError(f"Processed call is not EnvToolCall: {type(agent_call)}")
        if agent_call.tool != "press_button":
            raise ValueError(f"Unknown tool: {agent_call.tool}. Expected 'press_button'.")

        return agent_call

    async def step(
        self, tool_calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]]
    ) -> InternalObservation:
        """Execute one step in the Pokemon Red environment.

        Only the first call in ``tool_calls`` is executed (see
        ``validate_tool_calls``). If the tool fails, or its payload lacks
        either state, the observation is built from freshly created engine
        states with zero reward. In every case a tool error, if present, is
        copied onto ``pub_state.error_info`` when that attribute exists.
        """
        agent_call = self.validate_tool_calls(tool_calls)
        tool_result: ToolResult = await self._press_button_tool(agent_call)

        priv_state = pub_state = None
        payload_dict = tool_result.payload
        if tool_result.ok and isinstance(payload_dict, dict):
            # Successful tool calls carry the engine's state objects directly.
            priv_state = payload_dict.get("private")
            pub_state = payload_dict.get("public")

        if priv_state is None or pub_state is None:
            # Fallback: tool failed or returned an incomplete payload.
            priv_state, pub_state = self.engine._create_states(reward=0.0)

        if tool_result.error and hasattr(pub_state, "error_info"):
            pub_state.error_info = tool_result.error

        return await self._to_observation(
            priv_state, pub_state, self.custom_step_observation_callable
        )

    async def checkpoint(self) -> InternalObservation:
        """Return a checkpoint observation that embeds the serialized engine.

        When the observation is a dict, the snapshot is added under
        ``engine_snapshot_data``.
        """
        engine_snapshot: PokemonRedEngineSnapshot = await self.engine._serialize_engine()
        priv, pub = self.engine._create_states(reward=0.0)
        obs_data = await self._to_observation(
            priv, pub, self.custom_checkpoint_observation_callable
        )
        if isinstance(obs_data, dict):
            # NOTE(review): assumes the snapshot exposes pydantic-style
            # ``model_dump()`` — confirm against PokemonRedEngineSnapshot.
            obs_data["engine_snapshot_data"] = engine_snapshot.model_dump()
        return obs_data

    async def _to_observation(
        self,
        priv: PokemonRedPrivateState,
        pub: PokemonRedPublicState,
        obs_cb: Optional[GetObservationCallable],
        extra_obs: Optional[Dict[str, Any]] = None,
    ) -> InternalObservation:
        """Build an observation with ``obs_cb``, defaulting to the built-in one.

        ``extra_obs`` entries are merged in (overriding existing keys) only when
        the callback returned a dict.
        """
        active_obs_cb = obs_cb or PokemonRedObservationCallable()
        observation = await active_obs_cb.get_observation(pub, priv)
        if extra_obs and isinstance(observation, dict):
            observation.update(extra_obs)
        return observation

    # ReproducibleEnvironment methods
    async def _serialize_engine(self) -> PokemonRedEngineSnapshot:
        """Delegate snapshotting to the wrapped engine."""
        return await self.engine._serialize_engine()

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: PokemonRedEngineSnapshot, task_instance: PokemonRedTaskInstance
    ) -> "PokemonRedEnvironment":
        """Rebuild an environment around an engine restored from ``snapshot``.

        ``cls(task_instance)`` still builds a fresh engine in ``__init__``. That
        engine is replaced by the restored one, and the press-button tool is
        rebound so that ``step`` drives the restored engine, not the discarded
        one.
        """
        eng = await PokemonRedEngine._deserialize_engine(snapshot, task_instance)
        env = cls(task_instance)
        env.engine = eng
        env._press_button_tool = PressButtonTool(eng)
        return env
|
@@ -0,0 +1,77 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from pathlib import Path
|
4
|
+
import uuid
|
5
|
+
from synth_ai.environments.tasks.core import (
|
6
|
+
Task,
|
7
|
+
TaskInstance,
|
8
|
+
Impetus,
|
9
|
+
Intent,
|
10
|
+
TaskInstanceMetadata,
|
11
|
+
)
|
12
|
+
|
13
|
+
# Top-level Pokemon Red task: premise, constraints and the Boulder Badge goal.
TASK = Task(
    global_premises=(
        "You are playing Pokemon Red. "
        "Start in Pewter City with a level-10 Pikachu."
    ),
    global_constraints=(
        "No glitches or exploits. "
        "Play within normal game mechanics."
    ),
    global_objectives="Defeat Brock at the Pewter Gym to earn the Boulder Badge.",
    shared_env_params={},
)

# Location of an optional save state near Pewter Gym, stored next to this module
# under ``snapshots/``; the default instance uses it only if the file exists.
INITIAL_SNAPSHOT = Path(__file__).parent.joinpath("snapshots", "pewter_start.state")
|
23
|
+
|
24
|
+
|
25
|
+
@dataclass
class PokemonRedTaskInstance(TaskInstance):
    """Task instance for Pokemon Red challenges."""

    async def serialize(self) -> dict:
        """Serialize the task instance to a plain dictionary.

        ``gold_trajectories`` and ``metadata`` are not persisted (written as
        ``None`` and ``{}``). ``initial_engine_snapshot`` is written as a string
        path, or ``None`` when unset.
        """
        return {
            "id": str(self.id),
            "impetus": {"instructions": self.impetus.instructions},
            "intent": {
                "rubric": self.intent.rubric,
                "gold_trajectories": None,
                "gold_state_diff": self.intent.gold_state_diff,
            },
            "metadata": {},
            "is_reproducible": self.is_reproducible,
            "initial_engine_snapshot": str(self.initial_engine_snapshot)
            if self.initial_engine_snapshot
            else None,
        }

    @classmethod
    async def deserialize(cls, data: dict) -> "PokemonRedTaskInstance":
        """Rebuild a task instance from the output of ``serialize``.

        The snapshot path is restored as a ``Path``, so a serialize/deserialize
        round trip keeps pointing at the same save state. Dictionaries without
        the key, or with a ``None``/empty value, yield no snapshot.

        Raises:
            KeyError: if a required key (id, impetus, intent, is_reproducible)
                is missing.
        """
        snapshot = data.get("initial_engine_snapshot")
        return cls(
            id=uuid.UUID(data["id"]),
            impetus=Impetus(instructions=data["impetus"]["instructions"]),
            intent=Intent(
                rubric=data["intent"]["rubric"],
                gold_trajectories=None,
                gold_state_diff=data["intent"]["gold_state_diff"],
            ),
            metadata=TaskInstanceMetadata(),
            is_reproducible=data["is_reproducible"],
            initial_engine_snapshot=Path(snapshot) if snapshot else None,
        )
|
61
|
+
|
62
|
+
|
63
|
+
# Main task instance - beat Brock for Boulder Badge
|
64
|
+
# Default instance. The save state is attached only when the snapshot file is
# present on disk; otherwise no initial snapshot is set.
INSTANCE = PokemonRedTaskInstance(
    id=uuid.UUID("12345678-1234-5678-9abc-123456789abc"),
    impetus=Impetus(
        instructions=(
            "Navigate to Pewter Gym and defeat Brock to earn the Boulder Badge. "
            "Use strategic Pokemon battles and item management."
        )
    ),
    intent=Intent(
        rubric=(
            "Successfully obtain the Boulder Badge by defeating Brock at Pewter Gym. "
            "Efficiency measured by minimal steps and strategic Pokemon usage."
        ),
        gold_trajectories=None,
        gold_state_diff={"badges": 1},
    ),
    metadata=TaskInstanceMetadata(),
    is_reproducible=True,
    initial_engine_snapshot=(None if not INITIAL_SNAPSHOT.exists() else INITIAL_SNAPSHOT),
)
|
@@ -0,0 +1,125 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Test script to verify red environment fixes.
|
4
|
+
Tests JAX logging suppression and error handling.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import logging
|
9
|
+
import sys
|
10
|
+
|
11
|
+
from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
|
12
|
+
from synth_ai.environments.examples.red.taskset import INSTANCE as POKEMON_TASK
|
13
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
14
|
+
|
15
|
+
|
16
|
+
class PressButtonCall(EnvToolCall):
    """Shorthand ``EnvToolCall`` for the ``press_button`` tool.

    ``PressButtonCall("A")`` builds the same call as
    ``EnvToolCall(tool="press_button", args={"button": "A", "frames": 1})``.
    """

    def __init__(self, button: str, frames: int = 1):
        call_args = {"button": button, "frames": frames}
        super().__init__(tool="press_button", args=call_args)
|
21
|
+
|
22
|
+
|
23
|
+
async def test_environment_setup():
    """Create, initialize, step and terminate a ``PokemonRedEnvironment``.

    Prints a ✅ line for each stage that succeeds. Any exception is printed,
    logged with its traceback, and turns the result into ``False``. If a stage
    fails before termination, a best-effort ``terminate()`` is still attempted
    so the environment is not left running.

    Returns:
        bool: ``True`` if every stage completed, ``False`` otherwise.
    """
    print("Testing Pokemon Red environment setup...")

    env = None
    terminate_attempted = False
    try:
        # Create environment instance
        env = PokemonRedEnvironment(POKEMON_TASK)
        print("✅ Environment created successfully")

        # Try to initialize
        obs = await env.initialize()
        print("✅ Environment initialized successfully")
        print(f"Initial observation keys: {list(obs.keys())}")

        # Try a simple step
        obs = await env.step(PressButtonCall("A"))
        print("✅ Environment step executed successfully")
        print(
            f"Step observation: step_count={obs.get('step_count')}, terminated={obs.get('terminated')}"
        )

        # Terminate (the final observation is not inspected by this check).
        # Flag first so the cleanup below never retries a failing terminate().
        terminate_attempted = True
        await env.terminate()
        print("✅ Environment terminated successfully")

        return True

    except Exception as e:
        print(f"❌ Failed to setup environment: {e}")
        logging.exception("Failed to setup environment, aborting test")
        return False

    finally:
        # Best-effort cleanup when an earlier stage failed. A cleanup error is
        # logged rather than raised so it cannot mask the result above.
        if env is not None and not terminate_attempted:
            try:
                await env.terminate()
            except Exception:
                logging.exception("Cleanup terminate() after failure also failed")
|
54
|
+
|
55
|
+
|
56
|
+
def test_logging_configuration():
    """Check that the JAX loggers are configured to suppress debug output.

    This function does not change any logging configuration; it only inspects
    the logger levels as they are when it runs.

    Returns:
        bool: ``True`` if every listed JAX logger has an explicit level of
        WARNING or higher and ``jax._src.cache_key`` is not enabled for DEBUG,
        ``False`` otherwise.
    """
    print("Testing logging configuration...")

    # Check that JAX loggers are set to WARNING level
    jax_loggers = [
        "jax._src.cache_key",
        "jax._src.compilation_cache",
        "jax._src.compiler",
        "jax._src.dispatch",
    ]

    success = True
    for logger_name in jax_loggers:
        logger = logging.getLogger(logger_name)
        if logger.level >= logging.WARNING:
            print(f"✅ {logger_name} logger level: {logging.getLevelName(logger.level)}")
        else:
            print(
                f"❌ {logger_name} logger level: {logging.getLevelName(logger.level)} (should be WARNING or higher)"
            )
            success = False

    # Test that debug messages are suppressed. Ask the logger rather than
    # assuming: isEnabledFor() reflects the effective level and any
    # logging.disable() threshold.
    jax_logger = logging.getLogger("jax._src.cache_key")
    jax_logger.debug("This debug message should not appear")
    if jax_logger.isEnabledFor(logging.DEBUG):
        print("❌ JAX debug logging is still enabled")
        success = False
    else:
        print("✅ JAX debug logging appears to be suppressed")

    return success
|
81
|
+
|
82
|
+
|
83
|
+
def test_safe_compare():
    """Exercise ``safe_compare`` on mixed str/int operands.

    Prints a ✅/❌ line per case.

    Returns:
        bool: ``True`` if every case produced its expected result.
    """
    print("Testing safe comparison function...")

    from synth_ai.environments.examples.red.config_logging import safe_compare

    # (left, right, operator, expected result)
    test_cases = [
        ("5", 3, ">", True),  # String vs int
        (5, "3", ">", True),  # Int vs string
        ("abc", 5, ">", False),  # Invalid string vs int
        ("5", "3", ">", True),  # String vs string (numeric)
        ("abc", "def", ">", False),  # String vs string (alphabetic)
        (5, 3, ">", True),  # Normal int comparison
        ("10", 5, ">=", True),  # String number >= int
        (3, "10", "<=", True),  # Int <= string number
    ]

    success = True
    for left, right, op, expected in test_cases:
        result = safe_compare(left, right, op)
        status = "✅" if result == expected else "❌"
        print(f"{status} safe_compare({left}, {right}, '{op}') = {result} (expected {expected})")
        if result != expected:
            success = False

    return success
|
103
|
+
|
104
|
+
|
105
|
+
async def main():
    """Run every check, print an overall verdict and return an exit code.

    Returns:
        int: 0 if every check passed, 1 otherwise.
    """
    print("Running Pokemon Red environment fixes test...\n")

    # The synchronous checks print their own per-case results. Only an
    # explicit ``False`` counts as failure, so a check that merely reports
    # (returns ``None``) does not fail the run.

    # Test logging configuration
    logging_ok = test_logging_configuration() is not False
    print()

    # Test safe comparison
    compare_ok = test_safe_compare() is not False
    print()

    # Test environment setup
    setup_ok = await test_environment_setup()

    success = logging_ok and compare_ok and setup_ok
    print(f"\nOverall test result: {'✅ PASSED' if success else '❌ FAILED'}")
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))
|
@@ -0,0 +1,148 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Mock test script to verify red environment fixes without ROM file.
|
4
|
+
Tests JAX logging suppression and error handling.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
import sys
|
9
|
+
from unittest.mock import Mock, patch
|
10
|
+
|
11
|
+
|
12
|
+
def test_logging_configuration():
    """Apply the red-environment logging config and verify JAX loggers are quiet.

    Calls ``configure_logging()``, then checks that each listed JAX logger has
    an explicit level of WARNING or higher and that ``jax._src.cache_key`` is
    not enabled for DEBUG.

    Returns:
        bool: ``True`` if all checks passed, ``False`` otherwise.
    """
    print("Testing logging configuration...")

    # Import configuration to trigger setup
    from synth_ai.environments.examples.red.config_logging import configure_logging

    configure_logging()

    # Check that JAX loggers are set to WARNING level
    jax_loggers = [
        "jax._src.cache_key",
        "jax._src.compilation_cache",
        "jax._src.compiler",
        "jax._src.dispatch",
    ]

    success = True
    for logger_name in jax_loggers:
        logger = logging.getLogger(logger_name)
        if logger.level >= logging.WARNING:
            print(f"✅ {logger_name} logger level: {logging.getLevelName(logger.level)}")
        else:
            print(
                f"❌ {logger_name} logger level: {logging.getLevelName(logger.level)} (should be WARNING or higher)"
            )
            success = False

    # Test that debug messages are suppressed. Ask the logger rather than
    # printing success unconditionally: isEnabledFor() reflects the effective
    # level and any logging.disable() threshold.
    jax_logger = logging.getLogger("jax._src.cache_key")
    jax_logger.debug("This debug message should not appear")
    if jax_logger.isEnabledFor(logging.DEBUG):
        print("❌ JAX debug logging is still enabled")
        success = False
    else:
        print("✅ JAX debug logging appears to be suppressed")

    return success
|
46
|
+
|
47
|
+
|
48
|
+
def test_safe_compare():
    """Run ``safe_compare`` over a table of mixed-type comparison cases.

    Prints a ✅/❌ line for each case and returns ``True`` only when every
    case produced its expected result.
    """
    print("Testing safe comparison function...")

    from synth_ai.environments.examples.red.config_logging import safe_compare

    # (left, right, operator, expected). Includes the str-vs-int pairs that
    # previously triggered the comparison error.
    cases = (
        ("5", 3, ">", True),  # String vs int
        (5, "3", ">", True),  # Int vs string
        ("abc", 5, ">", False),  # Invalid string vs int
        ("5", "3", ">", True),  # String vs string (numeric)
        ("abc", "def", ">", False),  # String vs string (alphabetic)
        (5, 3, ">", True),  # Normal int comparison
        ("10", 5, ">=", True),  # String number >= int
        (3, "10", "<=", True),  # Int <= string number
    )

    mismatches = []
    for lhs, rhs, op_symbol, want in cases:
        got = safe_compare(lhs, rhs, op_symbol)
        mark = "✅" if got == want else "❌"
        print(f"{mark} safe_compare({lhs}, {rhs}, '{op_symbol}') = {got} (expected {want})")
        if got != want:
            mismatches.append((lhs, rhs, op_symbol))

    return not mismatches
|
75
|
+
|
76
|
+
|
77
|
+
def test_state_creation_error_handling():
    """Check that ``PokemonRedEngine._create_states`` tolerates bad state data.

    PyBoy is patched out so no ROM file is needed. ``_extract_current_state``
    is then patched twice: first to return string-typed values (including a
    non-numeric badge count), then to raise. Both calls to
    ``_create_states(0.0, False)`` must complete without raising.

    Returns:
        bool: ``True`` if both scenarios completed, ``False`` if anything raised.
    """
    print("Testing state creation error handling...")

    from synth_ai.environments.examples.red.engine import PokemonRedEngine
    from synth_ai.environments.examples.red.taskset import INSTANCE as POKEMON_TASK

    try:
        # Mock the PyBoy emulator to avoid ROM requirement. patch() resolves
        # its target by importing the dotted module path, so it must name the
        # same module imported above; the old "examples.red.engine" target did
        # not, and the patch failed before the engine was ever built.
        # NOTE(review): assumes the engine module binds the name ``PyBoy``.
        with patch("synth_ai.environments.examples.red.engine.PyBoy") as mock_pyboy:
            mock_emulator = Mock()
            mock_pyboy.return_value = mock_emulator

            # Create engine instance
            engine = PokemonRedEngine(POKEMON_TASK)

            # Mock extract_game_state to return problematic data that could cause comparison errors
            with patch.object(engine, "_extract_current_state") as mock_extract:
                # Test with string badges that could cause comparison error
                mock_extract.return_value = {
                    "map_id": "1",  # String instead of int
                    "player_x": "10",
                    "player_y": "20",
                    "badges": "abc",  # Non-numeric string
                    "in_battle": "false",  # String instead of bool
                    "party_level": "5",
                    "party_hp_current": "50",
                    "party_hp_max": "50",
                    "party_xp": "100",
                }

                # This should not crash due to our error handling
                priv_state, pub_state = engine._create_states(0.0, False)

                print("✅ State creation handles problematic data gracefully")
                print(f"✅ Created states: badges={pub_state.badges}, map_id={pub_state.map_id}")

                # Test with completely invalid data
                mock_extract.side_effect = Exception("Memory read error")
                priv_state, pub_state = engine._create_states(0.0, False)
                print("✅ State creation handles extraction errors gracefully")

        return True

    except Exception as e:
        print(f"❌ State creation error handling failed: {e}")
        # Keep the traceback so a failure here is diagnosable.
        logging.exception("State creation error handling check failed")
        return False
|
124
|
+
|
125
|
+
|
126
|
+
def main():
    """Run the ROM-free checks and return a process exit code.

    Runs the logging, safe-comparison and state-creation checks in order
    (all of them, even if an earlier one fails), printing a blank line after
    each.

    Returns:
        int: 0 when every check passes, otherwise 1.
    """
    print("Running Pokemon Red environment fixes test (mock version)...\n")

    checks = (
        test_logging_configuration,  # Test logging configuration
        test_safe_compare,  # Test safe comparison
        test_state_creation_error_handling,  # Test error handling
    )
    outcomes = []
    for check in checks:
        outcomes.append(check())
        print()

    success = all(outcomes)
    print(f"Overall test result: {'✅ PASSED' if success else '❌ FAILED'}")
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())
|
@@ -0,0 +1 @@
|
|
1
|
+
# Unit tests for Pokemon Red environment
|
@@ -0,0 +1,97 @@
|
|
1
|
+
import pytest
|
2
|
+
from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
|
3
|
+
from synth_ai.environments.examples.red.taskset import INSTANCE as POKEMON_TASK
|
4
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
5
|
+
|
6
|
+
|
7
|
+
class PressButtonCall(EnvToolCall):
    """Convenience ``EnvToolCall`` targeting the ``press_button`` tool.

    Args:
        button: Name of the button to press (e.g. ``"A"``).
        frames: Forwarded as the ``frames`` tool argument; defaults to 1.
    """

    def __init__(self, button: str, frames: int = 1):
        call_args = {"button": button, "frames": frames}
        super().__init__(tool="press_button", args=call_args)
|
12
|
+
|
13
|
+
|
14
|
+
@pytest.mark.asyncio
async def test_pokemon_red_basic():
    """Initialize, press two buttons, then terminate the environment.

    Asserts the initial observation reports zero badges. It also asserts that
    ``step_count`` is 1 and then 2 after two ``step`` calls, the second with
    ``frames=2``. Finally, ``terminate`` must return an observation with
    ``terminated`` set to ``True``.
    """
    env = PokemonRedEnvironment(POKEMON_TASK)

    initial = await env.initialize()
    for key in ("position", "badges_earned"):
        assert key in initial
    assert initial["badges_earned"] == 0  # a fresh game starts without badges

    after_a = await env.step(PressButtonCall("A"))
    assert "step_count" in after_a
    assert after_a["step_count"] == 1

    after_right = await env.step(PressButtonCall("RIGHT", 2))
    assert after_right["step_count"] == 2

    final_obs = await env.terminate()
    assert final_obs["terminated"] is True
|
36
|
+
|
37
|
+
|
38
|
+
@pytest.mark.asyncio
async def test_pokemon_red_multiple_actions():
    """Test a sequence of actions in Pokemon Red.

    After each step the observation must carry ``position``, ``hp_status`` and
    ``party_level``. After all steps, ``total_reward`` must not exceed its
    initial value and ``step_count`` must equal the number of actions. The
    environment is terminated afterwards even if an assertion fails.
    """
    env = PokemonRedEnvironment(POKEMON_TASK)

    obs = await env.initialize()
    try:
        initial_reward = obs["total_reward"]

        # Sequence of movements and actions
        actions = [
            PressButtonCall("RIGHT"),
            PressButtonCall("UP"),
            PressButtonCall("A"),
            PressButtonCall("DOWN"),
            PressButtonCall("B"),
        ]

        for action in actions:
            obs = await env.step(action)
            assert "position" in obs
            assert "hp_status" in obs
            assert "party_level" in obs

        # Should have accumulated some reward (mostly negative from step penalty)
        assert obs["total_reward"] <= initial_reward  # Step penalties
        assert obs["step_count"] == len(actions)
    finally:
        # Release the environment, as test_pokemon_red_basic does explicitly.
        await env.terminate()
|
64
|
+
|
65
|
+
|
66
|
+
@pytest.mark.asyncio
async def test_pokemon_red_checkpointing():
    """Test environment checkpointing functionality.

    After two steps, ``checkpoint()`` must report ``step_count == 2`` and
    include ``engine_snapshot_data`` with ``state_data``, ``total_reward`` and
    ``step_count`` keys. The environment is terminated afterwards even if an
    assertion fails.
    """
    env = PokemonRedEnvironment(POKEMON_TASK)

    # Initialize and take some steps
    await env.initialize()
    try:
        await env.step(PressButtonCall("RIGHT"))
        await env.step(PressButtonCall("A"))

        # Create checkpoint
        checkpoint_obs = await env.checkpoint()
        assert "engine_snapshot_data" in checkpoint_obs
        assert checkpoint_obs["step_count"] == 2

        # Verify checkpoint contains expected data
        snapshot_data = checkpoint_obs["engine_snapshot_data"]
        assert "state_data" in snapshot_data
        assert "total_reward" in snapshot_data
        assert "step_count" in snapshot_data
    finally:
        # Release the environment so it does not outlive the test.
        await env.terminate()
|
86
|
+
|
87
|
+
|
88
|
+
@pytest.mark.asyncio
async def test_pokemon_red_invalid_button():
    """Test handling of invalid button inputs.

    Stepping with an unknown button name must still return an observation
    containing ``position``. The environment is terminated afterwards even if
    the step or assertion fails.
    """
    env = PokemonRedEnvironment(POKEMON_TASK)
    await env.initialize()
    try:
        # Test with invalid button - should handle gracefully
        obs = await env.step(PressButtonCall("INVALID_BUTTON"))
        # Should still return valid observation even if action failed
        assert "position" in obs
    finally:
        # Release the environment so it does not outlive the test.
        await env.terminate()
|