synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.1.9.dist-info/METADATA +0 -37
- synth_ai-0.1.9.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,748 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import uuid
|
5
|
+
import pytest
|
6
|
+
import json
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Dict, Any, List, Optional, Deque
|
9
|
+
from pydantic import BaseModel, Field
|
10
|
+
from collections import deque
|
11
|
+
from synth_ai.zyk import LM
|
12
|
+
from synth_sdk.tracing.decorators import trace_event_async
|
13
|
+
from synth_sdk.tracing.abstractions import RewardSignal, Dataset, TrainingQuestion
|
14
|
+
from synth_sdk.tracing.utils import get_system_id
|
15
|
+
from synth_ai.environments.examples.sokoban.environment import (
|
16
|
+
SokobanEnvironment,
|
17
|
+
SokobanPublicState,
|
18
|
+
SokobanPrivateState,
|
19
|
+
)
|
20
|
+
from synth_ai.environments.examples.sokoban.engine import (
|
21
|
+
_grid_to_text,
|
22
|
+
ACTION_STRING_TO_INT,
|
23
|
+
)
|
24
|
+
from synth_ai.environments.environment.shared_engine import (
|
25
|
+
GetObservationCallable,
|
26
|
+
InternalObservation,
|
27
|
+
)
|
28
|
+
from synth_ai.environments.examples.sokoban.taskset import (
|
29
|
+
SokobanTaskInstance,
|
30
|
+
SokobanTaskInstanceMetadata,
|
31
|
+
)
|
32
|
+
from synth_ai.environments.tasks.core import Impetus, Intent
|
33
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
34
|
+
from dataclasses import dataclass
|
35
|
+
|
36
|
+
import logging

# Silence ALL logging for this evaluation run: disables every record at
# CRITICAL and below, process-wide, so agent/environment chatter does not
# drown the evaluation output.
logging.disable(logging.CRITICAL)
|
39
|
+
|
40
|
+
|
41
|
+
@dataclass
class AgentDecisionRecord:
    """Record of agent's decision-making process including messages, tool calls, and results."""

    # Chosen action as an int index into the Sokoban action space
    # (-1 is used to signal an accepted terminate request).
    action_int: int
    # Chat messages sent to the LLM for this decision (system + user).
    input_messages: List[Dict[str, Any]]
    # Messages attributed to the LLM's reply for this decision.
    output_messages: List[Dict[str, Any]]
    # Structured tool calls extracted from the LLM response
    # (OpenAI function-call format: id / type / function{name, arguments}).
    tool_calls: List[Dict[str, Any]]
    # One result entry per tool call ({"tool_call_id": ..., "content": ...}).
    tool_results: List[Dict[str, Any]]
    # Free-text reasoning from the model, or a fallback/error note.
    reasoning_text: str
    # Identifier of the model that produced this decision.
    model_name: str
    # Raw provider response object, kept for debugging; may be None.
    raw_response: Any = None
|
53
|
+
|
54
|
+
|
55
|
+
# --- Helper function to format observation for LLM ---
|
56
|
+
def format_obs_for_llm_from_states(pub: SokobanPublicState, priv: SokobanPrivateState) -> str:
    """Render the current Sokoban state as a text observation for the LLM.

    Produces the ASCII board followed by a status footer (boxes on target,
    steps taken, termination flag, last reward).  When the environment
    rejected the previous action (``last_action_name`` starts with
    ``"INVALID_ACTION_NO_CHANGE"``), a "no change" notice naming the
    rejected action is prepended so the agent can react to it.

    Args:
        pub: Public state (board grid, last action name, step/box counters).
        priv: Private state (termination flag, last reward).

    Returns:
        A multi-line observation string.
    """
    room_text = _grid_to_text(pub.room_state)

    # Single shared footer; previously this f-string block was duplicated
    # verbatim in both branches, which invited drift between the two copies.
    status = (
        f"{room_text}\n"
        f"Boxes on Target: {pub.boxes_on_target} / {pub.num_boxes}\n"
        f"Steps Taken: {pub.num_steps} / {pub.max_steps}\n"
        f"Terminated: {priv.terminated}\n"
        f"Last Reward: {priv.reward_last}"
    )

    if pub.last_action_name.startswith("INVALID_ACTION_NO_CHANGE"):
        # Surface the rejected action's name (the text after the colon) so
        # the agent knows exactly which move was a no-op.
        return (
            f"Previous action ({pub.last_action_name.split(': ')[-1]}) resulted in NO CHANGE to the board.\n"
            + status
        )

    # Default formatting for valid actions or the initial state.
    return status
|
78
|
+
|
79
|
+
|
80
|
+
# ---------------------------------- custom observation callable ------------------------------ #
|
81
|
+
class HistoryObservationCallable(GetObservationCallable):
    """Observation callable that augments each observation with a rolling
    window of recent board renderings (most recent last)."""

    def __init__(self, max_history: int = 3):
        # Bounded queue of board strings; deque(maxlen=...) silently drops
        # the oldest entry once the window is full.
        self._hist: Deque[str] = deque(maxlen=max_history)

    async def get_observation(
        self, pub: SokobanPublicState, priv: SokobanPrivateState
    ) -> InternalObservation:
        """Return public/private state plus the recent board history."""
        # Defensive path: env.terminate() may not provide full states, in
        # which case an error payload is returned instead of raising.
        if pub is None or priv is None:
            return {
                "error": "Missing public or private state in get_observation",
                "history_boards": list(self._hist),
            }  # type: ignore[return-value]

        board_text = _grid_to_text(pub.room_state)
        self._hist.append(board_text)

        # Hand back both raw states and a snapshot of the history window.
        return {
            "public": pub,
            "private": priv,
            "history_boards": list(self._hist),
        }  # type: ignore[return-value]
|
103
|
+
|
104
|
+
|
105
|
+
# --- Pydantic Models for Tool Arguments ---
|
106
|
+
class SokobanInteractArgs(BaseModel):
    """Arguments schema for the `sokoban_interact` tool exposed to the LLM."""

    actions_list: List[str] = Field(
        description="List of actions to execute. Valid actions: move up, move down, move left, move right, push up, push down, push left, push right, no operation"
    )
    reasoning: str = Field(description="Reasoning for the chosen actions")
|
111
|
+
|
112
|
+
|
113
|
+
class TerminateArgs(BaseModel):
    """Arguments schema for the `terminate` tool exposed to the LLM."""

    reasoning: str = Field(description="Reasoning for terminating the agent")
|
115
|
+
|
116
|
+
|
117
|
+
# --- tiny ReAct agent -------------------------------------------------- #
|
118
|
+
class Move(EnvToolCall):
    """Convenience wrapper: an `interact` tool call carrying one action index."""

    def __init__(self, action: int):
        super().__init__(tool="interact", args={"action": action})
|
121
|
+
|
122
|
+
|
123
|
+
class ReActAgent:
|
124
|
+
def __init__(self, llm, max_turns: int = 10):
    """Initialize the ReAct agent.

    Args:
        llm: Language-model client; must expose ``respond_async``.
        max_turns: Maximum number of decision turns the agent may take.
    """
    self.llm, self.max_turns = llm, max_turns
    # Interleaved log of observations, tool calls and tool responses,
    # rendered into the prompt by _format_history_for_prompt().
    self.history: List[Dict[str, Any]] = []
    self.system_name: str = "sokoban-react-ex"
    self.system_id: Any = get_system_id(self.system_name)
    # Unique id for this agent instance — presumably consumed by the
    # synth_sdk tracing decorators; confirm against the tracing backend.
    self.system_instance_id: str = str(uuid.uuid4())
    # Most recent observation dict; used by decide() to validate
    # terminate requests against actual puzzle state.
    self.last_obs_dict: Optional[Dict[str, Any]] = None
    # Total boxes in the puzzle; 0 until set — assumed to be assigned by
    # the caller once the task is known (TODO confirm).
    self.num_total_boxes: int = 0

    # OpenAI-style tool schemas offered to the model on every turn.
    self.tools = [
        {
            "type": "function",
            "function": {
                "name": "sokoban_interact",
                "description": "Interacts with the Sokoban environment by proposing a single action.",
                "parameters": SokobanInteractArgs.model_json_schema(),
            },
        },
        {
            "type": "function",
            "function": {
                "name": "terminate",
                "description": "Terminates the agent's execution if the puzzle is solved or no further moves are required.",
                "parameters": TerminateArgs.model_json_schema(),
            },
        },
    ]
|
151
|
+
|
152
|
+
def _format_history_for_prompt(self) -> str:
    """Render self.history as a ReAct-style transcript string."""
    rendered: List[str] = []
    for item in self.history:
        kind = item["type"]
        if kind == "obs":
            rendered.append(f"OBSERVATION:\n{item['content']}")
        elif kind == "tool_call":
            serialized_args = json.dumps(item["tool_arguments"])
            rendered.append(
                f"THOUGHT:\nI will call the tool `{item['tool_name']}` with arguments: {serialized_args}\nACTION: (Tool call executed)"
            )
        elif kind == "tool_response":
            # Tool output itself is not replayed; the next observation
            # entry carries the resulting state.
            rendered.append(
                "TOOL_RESPONSE:\n(Action executed, new observation will follow if not terminal)"
            )
    return "\n".join(rendered)
|
167
|
+
|
168
|
+
@trace_event_async(event_type="react_agent_decide")
|
169
|
+
async def decide(self, obs: str) -> AgentDecisionRecord:
|
170
|
+
"""
|
171
|
+
Make a decision and return a complete record of the reasoning process.
|
172
|
+
"""
|
173
|
+
self.history.append({"type": "obs", "content": obs})
|
174
|
+
|
175
|
+
formatted_prompt_history = self._format_history_for_prompt()
|
176
|
+
user_prompt = f"{formatted_prompt_history}\n\nBased on the history above, particularly the last observation, what is your reasoning and which tool should you call next?"
|
177
|
+
|
178
|
+
system_prompt = (
|
179
|
+
"You are an agent playing Sokoban. Your goal is to push all boxes onto the target locations. "
|
180
|
+
"Review the history of observations, thoughts, and actions. "
|
181
|
+
"Based on this history, particularly the last observation, decide on the best next action. "
|
182
|
+
"You MUST call one of the two available tools: `sokoban_interact` or `terminate`.\n\n"
|
183
|
+
"Action Guide:\n"
|
184
|
+
"- Use 'move' actions (move up, move down, move left, move right) when moving to empty spaces\n"
|
185
|
+
"- Use 'push' actions (push up, push down, push left, push right) when pushing boxes onto targets\n"
|
186
|
+
"- Use 'no operation' to skip a turn\n\n"
|
187
|
+
"Please use the tools available to you. Do not attempt to include a tool call in your reasoning"
|
188
|
+
)
|
189
|
+
|
190
|
+
# Create input messages
|
191
|
+
input_messages = [
|
192
|
+
{"role": "system", "content": system_prompt},
|
193
|
+
{"role": "user", "content": user_prompt},
|
194
|
+
]
|
195
|
+
|
196
|
+
# Get response from LLM
|
197
|
+
response_obj = await self.llm.respond_async(
|
198
|
+
system_message=system_prompt, user_message=user_prompt, tools=self.tools
|
199
|
+
)
|
200
|
+
|
201
|
+
# Initialize record fields
|
202
|
+
tool_calls = []
|
203
|
+
tool_results = []
|
204
|
+
reasoning_text = ""
|
205
|
+
action_int = ACTION_STRING_TO_INT["no operation"] # Default fallback
|
206
|
+
|
207
|
+
# Handle cases where the model does **not** return a structured tool call.
|
208
|
+
if not getattr(response_obj, "tool_calls", None):
|
209
|
+
# Record fallback in history for later analysis
|
210
|
+
reasoning_text = "LLM failed to provide tool_calls; fallback action."
|
211
|
+
self.history.append(
|
212
|
+
{
|
213
|
+
"type": "tool_call",
|
214
|
+
"tool_name": "sokoban_interact",
|
215
|
+
"tool_arguments": {
|
216
|
+
"actions_list": ["move down"],
|
217
|
+
"reasoning": reasoning_text,
|
218
|
+
},
|
219
|
+
}
|
220
|
+
)
|
221
|
+
action_int = ACTION_STRING_TO_INT["move down"]
|
222
|
+
|
223
|
+
# Create fallback tool call record
|
224
|
+
tool_calls = [
|
225
|
+
{
|
226
|
+
"id": "fallback_0",
|
227
|
+
"type": "function",
|
228
|
+
"function": {
|
229
|
+
"name": "sokoban_interact",
|
230
|
+
"arguments": json.dumps(
|
231
|
+
{"actions_list": ["move down"], "reasoning": reasoning_text}
|
232
|
+
),
|
233
|
+
},
|
234
|
+
}
|
235
|
+
]
|
236
|
+
|
237
|
+
tool_results = [{"tool_call_id": "fallback_0", "content": "Fallback action: move down"}]
|
238
|
+
else:
|
239
|
+
# Process successful tool calls
|
240
|
+
response_tool_calls = None
|
241
|
+
|
242
|
+
try:
|
243
|
+
if hasattr(response_obj, "tool_calls") and response_obj.tool_calls:
|
244
|
+
response_tool_calls = response_obj.tool_calls
|
245
|
+
elif isinstance(response_obj, str):
|
246
|
+
try:
|
247
|
+
potential_tool_call_json = json.loads(response_obj)
|
248
|
+
if (
|
249
|
+
isinstance(potential_tool_call_json, dict)
|
250
|
+
and "tool_calls" in potential_tool_call_json
|
251
|
+
):
|
252
|
+
response_tool_calls = potential_tool_call_json["tool_calls"]
|
253
|
+
elif (
|
254
|
+
isinstance(potential_tool_call_json, list)
|
255
|
+
and len(potential_tool_call_json) > 0
|
256
|
+
and potential_tool_call_json[0].get("type") == "function"
|
257
|
+
):
|
258
|
+
response_tool_calls = potential_tool_call_json
|
259
|
+
except json.JSONDecodeError:
|
260
|
+
pass
|
261
|
+
|
262
|
+
if response_tool_calls and len(response_tool_calls) > 0:
|
263
|
+
tool_call_data = response_tool_calls[0]
|
264
|
+
|
265
|
+
tool_name = ""
|
266
|
+
tool_args_str = ""
|
267
|
+
|
268
|
+
if (
|
269
|
+
hasattr(tool_call_data, "function")
|
270
|
+
and hasattr(tool_call_data.function, "name")
|
271
|
+
and hasattr(tool_call_data.function, "arguments")
|
272
|
+
):
|
273
|
+
tool_name = tool_call_data.function.name
|
274
|
+
tool_args_str = tool_call_data.function.arguments
|
275
|
+
elif (
|
276
|
+
isinstance(tool_call_data, dict)
|
277
|
+
and "function" in tool_call_data
|
278
|
+
and isinstance(tool_call_data["function"], dict)
|
279
|
+
):
|
280
|
+
tool_name = tool_call_data["function"].get("name")
|
281
|
+
tool_args_str = tool_call_data["function"].get("arguments")
|
282
|
+
if not isinstance(tool_args_str, str):
|
283
|
+
tool_args_str = json.dumps(tool_args_str)
|
284
|
+
|
285
|
+
if tool_name and tool_args_str:
|
286
|
+
tool_arguments = json.loads(tool_args_str)
|
287
|
+
|
288
|
+
# Create proper tool call record
|
289
|
+
tool_calls = [
|
290
|
+
{
|
291
|
+
"id": f"call_{len(self.history)}",
|
292
|
+
"type": "function",
|
293
|
+
"function": {
|
294
|
+
"name": tool_name,
|
295
|
+
"arguments": tool_args_str,
|
296
|
+
},
|
297
|
+
}
|
298
|
+
]
|
299
|
+
|
300
|
+
# Record in history
|
301
|
+
self.history.append(
|
302
|
+
{
|
303
|
+
"type": "tool_call",
|
304
|
+
"tool_name": tool_name,
|
305
|
+
"tool_arguments": tool_arguments,
|
306
|
+
}
|
307
|
+
)
|
308
|
+
|
309
|
+
# Process the tool call
|
310
|
+
if tool_name == "sokoban_interact":
|
311
|
+
validated_args = SokobanInteractArgs(**tool_arguments)
|
312
|
+
reasoning_text = validated_args.reasoning
|
313
|
+
|
314
|
+
if validated_args.actions_list:
|
315
|
+
action_str = validated_args.actions_list[0]
|
316
|
+
action_int = ACTION_STRING_TO_INT.get(
|
317
|
+
action_str.lower(),
|
318
|
+
ACTION_STRING_TO_INT["no operation"],
|
319
|
+
)
|
320
|
+
|
321
|
+
tool_results = [
|
322
|
+
{
|
323
|
+
"tool_call_id": f"call_{len(self.history) - 1}",
|
324
|
+
"content": f"Executed action: {action_str}",
|
325
|
+
}
|
326
|
+
]
|
327
|
+
else:
|
328
|
+
tool_results = [
|
329
|
+
{
|
330
|
+
"tool_call_id": f"call_{len(self.history) - 1}",
|
331
|
+
"content": "No action specified, using no operation",
|
332
|
+
}
|
333
|
+
]
|
334
|
+
|
335
|
+
elif tool_name == "terminate":
|
336
|
+
validated_args = TerminateArgs(**tool_arguments)
|
337
|
+
reasoning_text = validated_args.reasoning
|
338
|
+
|
339
|
+
# Check if termination is valid
|
340
|
+
if self.last_obs_dict:
|
341
|
+
terminated_by_env = self.last_obs_dict.get("terminated", False)
|
342
|
+
boxes_on_target = int(self.last_obs_dict.get("boxes_on_target", 0))
|
343
|
+
is_solved_state = (
|
344
|
+
self.num_total_boxes > 0
|
345
|
+
and boxes_on_target == self.num_total_boxes
|
346
|
+
)
|
347
|
+
|
348
|
+
if terminated_by_env or is_solved_state:
|
349
|
+
action_int = -1 # Terminate
|
350
|
+
tool_results = [
|
351
|
+
{
|
352
|
+
"tool_call_id": f"call_{len(self.history) - 1}",
|
353
|
+
"content": "Termination accepted - puzzle solved or environment terminated",
|
354
|
+
}
|
355
|
+
]
|
356
|
+
else:
|
357
|
+
action_int = ACTION_STRING_TO_INT["no operation"]
|
358
|
+
tool_results = [
|
359
|
+
{
|
360
|
+
"tool_call_id": f"call_{len(self.history) - 1}",
|
361
|
+
"content": f"Termination rejected - puzzle not solved. Boxes on target: {boxes_on_target}/{self.num_total_boxes}",
|
362
|
+
}
|
363
|
+
]
|
364
|
+
else:
|
365
|
+
action_int = ACTION_STRING_TO_INT["no operation"]
|
366
|
+
tool_results = [
|
367
|
+
{
|
368
|
+
"tool_call_id": f"call_{len(self.history) - 1}",
|
369
|
+
"content": "Termination rejected - cannot verify puzzle state",
|
370
|
+
}
|
371
|
+
]
|
372
|
+
|
373
|
+
except Exception as e:
|
374
|
+
reasoning_text = f"Error processing LLM response: {str(e)}"
|
375
|
+
self.history.append({"type": "error", "content": reasoning_text})
|
376
|
+
action_int = ACTION_STRING_TO_INT["no operation"]
|
377
|
+
|
378
|
+
tool_calls = [
|
379
|
+
{
|
380
|
+
"id": f"error_{len(self.history)}",
|
381
|
+
"type": "function",
|
382
|
+
"function": {
|
383
|
+
"name": "sokoban_interact",
|
384
|
+
"arguments": json.dumps(
|
385
|
+
{
|
386
|
+
"actions_list": ["no operation"],
|
387
|
+
"reasoning": reasoning_text,
|
388
|
+
}
|
389
|
+
),
|
390
|
+
},
|
391
|
+
}
|
392
|
+
]
|
393
|
+
|
394
|
+
tool_results = [
|
395
|
+
{
|
396
|
+
"tool_call_id": f"error_{len(self.history)}",
|
397
|
+
"content": f"Error occurred: {str(e)}",
|
398
|
+
}
|
399
|
+
]
|
400
|
+
|
401
|
+
# Create output messages
|
402
|
+
output_messages = [
|
403
|
+
{"role": "assistant", "content": reasoning_text, "tool_calls": tool_calls}
|
404
|
+
]
|
405
|
+
|
406
|
+
# Add tool results as separate messages
|
407
|
+
for tool_result in tool_results:
|
408
|
+
output_messages.append(
|
409
|
+
{
|
410
|
+
"role": "tool",
|
411
|
+
"tool_call_id": tool_result["tool_call_id"],
|
412
|
+
"content": tool_result["content"],
|
413
|
+
}
|
414
|
+
)
|
415
|
+
|
416
|
+
# Create and return the decision record
|
417
|
+
return AgentDecisionRecord(
|
418
|
+
action_int=action_int,
|
419
|
+
input_messages=input_messages,
|
420
|
+
output_messages=output_messages,
|
421
|
+
tool_calls=tool_calls,
|
422
|
+
tool_results=tool_results,
|
423
|
+
reasoning_text=reasoning_text,
|
424
|
+
model_name=self.llm.model_name,
|
425
|
+
raw_response=response_obj,
|
426
|
+
)
|
427
|
+
|
428
|
+
|
429
|
+
# --- test ---------------------------------------------------------------- #
# Minimal 4x4 Sokoban puzzle used by the unit test below.
# Cell codes (from the inline annotations): 2 = target, 4 = box, 5 = player;
# presumably 0 = wall and 1 = floor — confirm against the Sokoban engine docs.
SIMPLE_SNAPSHOT: Dict[str, Any] = dict(
    dim_room=[4, 4],
    room_fixed=[
        [0, 0, 0, 0],
        [0, 1, 2, 1],  # target at (1,2)
        [0, 1, 1, 1],
        [0, 0, 0, 0],
    ],
    room_state=[
        [0, 0, 0, 0],
        [0, 1, 1, 1],
        [0, 1, 4, 1],  # box at (2,2)
        [0, 5, 1, 1],  # player at (3,1)
    ],
    boxes_on_target=0,
    max_steps=10,
    num_boxes=1,
)
|
448
|
+
|
449
|
+
|
450
|
+
@pytest.mark.asyncio
async def test_react_agent_sokoban(tmp_path: Path):
    """Run one ReAct-agent episode on the SIMPLE_SNAPSHOT puzzle and record a Dataset.

    The test builds a single reproducible task instance, drives the agent/environment
    loop until the agent terminates or the environment reports termination, then
    packages the outcome as a RewardSignal. The final asserts are intentionally
    commented out, so this currently acts as a smoke test.
    """
    inst = SokobanTaskInstance(
        id=uuid.uuid4(),
        impetus=Impetus(instructions="solve"),
        intent=Intent(rubric={}, gold_trajectories=None, gold_state_diff={}),
        metadata=SokobanTaskInstanceMetadata("easy", 1, (4, 4), 10, -1, -1, "unit"),
        is_reproducible=True,
        initial_engine_snapshot=SIMPLE_SNAPSHOT,
    )
    # Keep the last 3 board states in each observation payload.
    hist_cb = HistoryObservationCallable(max_history=3)
    env = SokobanEnvironment(inst, custom_step_obs=hist_cb)
    env.engine.package_sokoban_env.render_mode = "raw"  # type: ignore[attr-defined]

    llm = LM(model_name="gpt-4.1-nano", formatting_model_name="gpt-4.1-nano", temperature=0.0)
    agent = ReActAgent(llm)

    async def run_episode():
        """Drive the agent/env loop; return the environment's terminated flag."""
        obs_payload = await env.initialize()

        # Ensure payload is not an error structure from callable
        if "error" in obs_payload:
            return False  # Or handle error appropriately

        # Seed the agent's view of the world so its terminate-validation logic
        # (in decide()) can check solved state before the first step.
        agent.last_obs_dict = {
            "terminated": obs_payload["private"].terminated,
            "boxes_on_target": obs_payload["public"].boxes_on_target,
        }
        agent.num_total_boxes = obs_payload["public"].num_boxes
        current_input_to_agent = format_obs_for_llm_from_states(
            obs_payload["public"], obs_payload["private"]
        )

        for turn in range(agent.max_turns):
            decision_record = await agent.decide(current_input_to_agent)
            act_idx = decision_record.action_int

            # -1 is the agent's accepted-terminate sentinel from decide().
            if act_idx == -1:
                obs_payload_next = obs_payload
                break

            # NOTE(review): this passes [[Move(act_idx)]] while eval_react_sokoban's
            # run_episode passes a {"tool": "interact", ...} dict — confirm which
            # payload shape SokobanEnvironment.step actually expects.
            step_result = await env.step([[Move(act_idx)]])

            obs_payload_next = step_result
            if "error" in obs_payload_next:
                break  # Or handle error appropriately

            agent.last_obs_dict = {
                "terminated": obs_payload_next["private"].terminated,
                "boxes_on_target": obs_payload_next["public"].boxes_on_target,
            }
            # agent.num_total_boxes is assumed constant after initialization

            agent.history.append({"type": "tool_response", "content": "Action executed"})

            current_input_to_agent = format_obs_for_llm_from_states(
                obs_payload_next["public"], obs_payload_next["private"]
            )

            # obs_payload_next["history_boards"] already contains the history *including* the most recent board
            # due to how _hist.append() and list(self._hist) is structured in the callable now.
            # So, history_boards is a list of the N most recent board states, newest last.
            displayed_boards = obs_payload_next["history_boards"]
            # for i, board_text in enumerate(displayed_boards):
            #     t-0 is the newest, t-(N-1) is the oldest in the deque

            obs_payload = obs_payload_next

            if obs_payload_next["private"].terminated:
                break

        # Guard for the (theoretical) zero-iteration case where the loop body
        # never bound obs_payload_next; falls back to the initial payload.
        if "obs_payload_next" not in locals():
            obs_payload_next = obs_payload

        if "error" in obs_payload_next:
            return False  # Indicate failure

        return obs_payload_next["private"].terminated

    solved_status = await run_episode()
    # Package the episode outcome; reward is binary solved/not-solved.
    dataset = Dataset(
        questions=[TrainingQuestion(id="sokoban_ep", intent="solve", criteria="solved")],
        reward_signals=[
            RewardSignal(
                question_id="sokoban_ep",
                system_instance_id=agent.system_instance_id,
                reward=1 if solved_status else 0,
                annotation=json.dumps({"agent_history": agent.history}),
            )
        ],
    )
    # upload(dataset=dataset)
    # assert solved_status

    # Print the agent's final reward using checkpoint observation
    # final_obs = await env.checkpoint()
    # if isinstance(final_obs, dict):
    # else:
|
549
|
+
|
550
|
+
async def eval_react_sokoban(
    model_name: str = "gpt-4.1-nano",  # Default will be overridden by caller
    formatting_model_name: str = "gpt-4.1-nano",  # Default will be overridden by caller
) -> List[Dict[str, Any]]:
    """
    Run ReAct agents on Sokoban instances of different difficulties for a given model,
    and returns a list of dictionaries containing aggregated results for each mode.

    Difficulty is defined by the shortest-solution length of generated rooms
    (1 / 3 / 5 steps). Each difficulty runs 3 episodes concurrently and the
    result rows are formatted for tabulate (Model / Difficulty / Solved / Rate).
    """
    from synth_ai.environments.examples.sokoban.engine_helpers.room_utils import (
        generate_room,
        get_shortest_action_path,
    )
    import asyncio
    import uuid

    current_model_name_for_eval = model_name  # Use passed-in model name

    # A throwaway agent is built solely to read its system_name for logging.
    _temp_llm_for_names = LM(
        model_name=current_model_name_for_eval,
        formatting_model_name=formatting_model_name,  # Use passed-in formatting model name
        temperature=0.0,
    )
    _temp_agent_for_names = ReActAgent(_temp_llm_for_names)
    actual_system_name = _temp_agent_for_names.system_name

    # Helper to run a single episode (remains largely the same, but uses current_model_name_for_eval)
    async def run_episode(inst) -> bool:
        """Run a single agent/instance episode and return True on success."""
        hist_cb = HistoryObservationCallable(max_history=3)
        env = SokobanEnvironment(inst, custom_step_obs=hist_cb)
        env.engine.package_sokoban_env.render_mode = "raw"  # type: ignore[attr-defined]
        llm_for_episode = LM(
            model_name=current_model_name_for_eval,  # Uses the model for this eval_react_sokoban call
            formatting_model_name=formatting_model_name,  # Uses the formatting model for this call
            temperature=0.0,
        )
        agent = ReActAgent(llm_for_episode)

        obs = await env.initialize()
        # Seed the agent's terminate-validation state before the first decision.
        agent.last_obs_dict = {
            "terminated": obs["private"].terminated,
            "boxes_on_target": obs["public"].boxes_on_target,
        }
        agent.num_total_boxes = obs["public"].num_boxes
        prompt_obs = format_obs_for_llm_from_states(obs["public"], obs["private"])

        for _ in range(agent.max_turns):
            decision_record = await agent.decide(prompt_obs)
            act_idx = decision_record.action_int
            if act_idx == -1:  # agent terminated
                break
            # NOTE(review): step payload here is a tool-call dict, while the pytest
            # test above passes [[Move(act_idx)]] — confirm which form the env expects.
            obs = await env.step([{"tool": "interact", "args": {"action": act_idx}}])
            if "error" in obs:  # safety guard
                return False
            agent.last_obs_dict = {
                "terminated": obs["private"].terminated,
                "boxes_on_target": obs["public"].boxes_on_target,
            }
            agent.history.append({"type": "tool_response", "content": "Action executed"})
            prompt_obs = format_obs_for_llm_from_states(obs["public"], obs["private"])
            if obs["private"].terminated:  # env solved
                break
        return obs["private"].terminated

    # Instance factory (remains the same)
    async def make_instances(label: str, target_len: int, n: int = 3):
        """Generate n reproducible 5x5 one-box rooms whose shortest solution is exactly target_len."""
        instances = []
        seed = 0
        # NOTE(review): unbounded search — if no seed ever yields a room with a
        # shortest path of exactly target_len this loop never ends; consider a cap.
        while len(instances) < n:
            room_structure, room_state, _, _ = generate_room(
                dim=(5, 5),
                initial_seed=seed,
                num_boxes=1,
                search_depth=max(10, target_len + 2),
            )
            path = get_shortest_action_path(room_structure, room_state, MAX_DEPTH=20)
            if len(path) == target_len:
                inst = SokobanTaskInstance(
                    id=uuid.uuid4(),
                    impetus=Impetus(instructions="Solve"),
                    intent=Intent(rubric={}, gold_trajectories=None, gold_state_diff={}),
                    metadata=SokobanTaskInstanceMetadata(
                        label, 1, (5, 5), 20, len(path), seed, f"len={target_len}"
                    ),
                    is_reproducible=True,
                    initial_engine_snapshot={
                        "dim_room": (5, 5),
                        "room_fixed": room_structure,
                        "room_state": room_state,
                        "boxes_on_target": 0,
                        "max_steps": 20,
                        "num_boxes": 1,
                    },
                )
                instances.append(inst)
            seed += 1
        return instances

    # Evaluation logic
    configs = [("ultra-easy", 1), ("easy", 3), ("medium", 5)]
    results_for_this_model = []  # Store list of dicts for this model's run

    print(
        f"\nStarting Sokoban ReAct Agent Evaluation for Model: {current_model_name_for_eval}, System: {actual_system_name}"
    )

    for label, step_len in configs:
        print(f"  Processing difficulty: {label} for model {current_model_name_for_eval}...")
        insts = await make_instances(label, step_len, n=3)  # 3 instances per difficulty
        # Episodes within one difficulty run concurrently.
        solved_statuses = await asyncio.gather(*(run_episode(i) for i in insts))
        num_solved = sum(solved_statuses)
        rate = num_solved / len(insts) if insts else 0.0
        results_for_this_model.append(
            {
                "Model": current_model_name_for_eval,
                "Difficulty": label,
                "Solved": f"{num_solved}/{len(insts)}",
                "Success Rate": f"{rate:.0%}",
            }
        )
        print(
            f"  Completed {label} for model {current_model_name_for_eval}: {num_solved}/{len(insts)} solved ({rate:.0%})"
        )

    return results_for_this_model
|
675
|
+
|
676
|
+
|
677
|
+
if __name__ == "__main__":
    # asyncio.run(eval_react_sokoban()) # Old way of running a single model

    async def run_all_sokoban_evals_parallel():
        """Evaluate every configured model concurrently and print one combined table."""
        model_configs = [
            {"model_name": "gpt-4.1-nano", "formatting_model_name": "gpt-4.1-nano"},
            {"model_name": "gpt-4.1", "formatting_model_name": "gpt-4.1"},
            {
                "model_name": "o4-mini",
                "formatting_model_name": "o4-mini",
            },  # Assuming o4-mini uses itself for formatting
        ]

        print("Starting parallel Sokoban evaluation for all specified models...")

        # Each eval_react_sokoban call yields a List[Dict[str, Any]] of result
        # rows, so per_model_tables is a list of such lists (one per model).
        per_model_tables = await asyncio.gather(
            *(
                eval_react_sokoban(
                    model_name=cfg["model_name"],
                    formatting_model_name=cfg["formatting_model_name"],
                )
                for cfg in model_configs
            )
        )

        print("\n=== ALL SOKOBAN EVALUATIONS COMPLETED ===")

        # Flatten the per-model row lists into one table for display.
        flat_rows = [row for table in per_model_tables for row in table]

        print("\n--- Combined Sokoban Evaluation Summary Table ---")
        from tabulate import tabulate  # Ensure tabulate is imported

        if flat_rows:
            # headers="keys" uses the dictionary keys as column headers.
            print(tabulate(flat_rows, headers="keys", tablefmt="github"))
        else:
            print("No Sokoban evaluation data to display.")

    asyncio.run(run_all_sokoban_evals_parallel())
|
727
|
+
|
728
|
+
# Model: o4-mini, System: sokoban-react-ex
|
729
|
+
# | Difficulty | Solved | Success Rate |
|
730
|
+
# |--------------|----------|----------------|
|
731
|
+
# | ultra-easy | 3/3 | 100% |
|
732
|
+
# | easy | 3/3 | 100% |
|
733
|
+
# | medium | 3/3 | 100% |
|
734
|
+
|
735
|
+
|
736
|
+
# Model: gpt-4.1, System: sokoban-react-ex
|
737
|
+
# | Difficulty | Solved | Success Rate |
|
738
|
+
# |--------------|----------|----------------|
|
739
|
+
# | ultra-easy | 1/3 | 33% |
|
740
|
+
# | easy | 0/3 | 0% |
|
741
|
+
# | medium | 0/3 | 0% |
|
742
|
+
|
743
|
+
# Model: gpt-4.1-nano, System: sokoban-react-ex
|
744
|
+
# | Difficulty | Solved | Success Rate |
|
745
|
+
# |--------------|----------|----------------|
|
746
|
+
# | ultra-easy | 0/3 | 0% |
|
747
|
+
# | easy | 0/3 | 0% |
|
748
|
+
# | medium | 0/3 | 0% |
|