synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.1.9.dist-info/METADATA +0 -37
- synth_ai-0.1.9.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,47 @@
|
|
1
|
+
#!/usr/bin/env python3
"""
Run Gemini 1.5 Flash evaluation on MiniGrid tasks.
"""

import asyncio
from eval_framework import run_minigrid_eval, get_success_rate

# Model under evaluation and the difficulty conditions swept below.
_MODEL = "gemini-1.5-flash-latest"
_DIFFICULTIES = ["easy", "medium"]
_RULE = "=" * 60


async def run_gemini_evaluation():
    """Run Gemini 1.5 Flash on 5 instances and display results."""

    # Banner describing the evaluation configuration.
    print("🚀 Running Gemini 1.5 Flash MiniGrid Evaluation")
    print(_RULE)
    print("Model: gemini-1.5-flash-latest")
    print("Instances: 5 trajectories per condition")
    print("Tasks: Empty-5x5-v0, DoorKey-5x5-v0")
    print("Difficulties: easy, medium")
    print(_RULE)

    # Kick off the evaluation sweep over the configured conditions.
    report = await run_minigrid_eval(
        model_names=[_MODEL],
        difficulties=list(_DIFFICULTIES),
        task_types=["Empty-5x5-v0", "DoorKey-5x5-v0"],  # Start with simple tasks
        num_trajectories=5,  # 5 instances as requested
        max_turns=30,
    )

    print("\n" + _RULE)
    print("📊 QUICK SUCCESS RATE SUMMARY")
    print(_RULE)

    # Per-difficulty success rates, then the overall average.
    for difficulty in _DIFFICULTIES:
        rate = get_success_rate(report, _MODEL, difficulty)
        print(f"Gemini 1.5 Flash ({difficulty:6}): {rate:5.1f}%")

    overall = get_success_rate(report, _MODEL)
    print(f"Overall Average: {overall:5.1f}%")

    return report


if __name__ == "__main__":
    # Entry point: run the async evaluation to completion.
    asyncio.run(run_gemini_evaluation())
|
@@ -0,0 +1,562 @@
|
|
1
|
+
"""ReAct agent demo for MiniGrid environment."""
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import json
|
5
|
+
import os
|
6
|
+
from datetime import datetime
|
7
|
+
from typing import Dict, Any, List, Optional
|
8
|
+
from pydantic import BaseModel, Field
|
9
|
+
import uuid
|
10
|
+
|
11
|
+
# Import SynthAI LM and BaseTool
|
12
|
+
from synth_ai.zyk import LM
|
13
|
+
from synth_ai.zyk.lms.tools.base import BaseTool
|
14
|
+
from synth_sdk.tracing.decorators import trace_event_async
|
15
|
+
import sys
|
16
|
+
import os
|
17
|
+
|
18
|
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
19
|
+
|
20
|
+
from synth_ai.environments.examples.minigrid.environment import MiniGridEnvironment
|
21
|
+
from synth_ai.environments.examples.minigrid.taskset import (
|
22
|
+
create_minigrid_taskset,
|
23
|
+
DEFAULT_MINIGRID_TASK,
|
24
|
+
)
|
25
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
26
|
+
|
27
|
+
|
28
|
+
# --- Pydantic Models for Tool Arguments ---
|
29
|
+
class MiniGridActArgs(BaseModel):
    """Arguments for MiniGrid action.

    Schema the LLM must fill when calling the ``minigrid_act`` tool.
    """

    # Name of the discrete MiniGrid action. The allowed values are only
    # enforced through the description shown to the LLM, not by pydantic.
    action: str = Field(
        description="The action to take. Must be one of: 'left', 'right', 'forward', 'pickup', 'drop', 'toggle', 'done'"
    )
    # Short free-text justification; useful when reviewing debug logs.
    reasoning: str = Field(description="A brief explanation of why this action was chosen")
|
36
|
+
|
37
|
+
|
38
|
+
class TerminateArgs(BaseModel):
    """Arguments for termination.

    Schema the LLM must fill when calling the ``terminate`` tool.
    """

    # Free-text explanation the agent gives for ending the episode early.
    reason: str = Field(description="Reason for termination")
|
42
|
+
|
43
|
+
|
44
|
+
# --- Tool Definitions ---
|
45
|
+
|
46
|
+
|
47
|
+
class MiniGridActTool(BaseTool):
    """Tool for performing an action in MiniGrid."""

    # Tool metadata consumed by the LM tool-calling interface:
    # the name the LLM must use, the pydantic argument schema, and the
    # description surfaced in the tool list.
    name: str = "minigrid_act"
    arguments: type[BaseModel] = MiniGridActArgs
    description: str = "Perform an action in the MiniGrid environment."
|
53
|
+
|
54
|
+
|
55
|
+
class TerminateTool(BaseTool):
    """Tool to terminate the episode."""

    # Tool metadata consumed by the LM tool-calling interface; selecting
    # this tool ends the episode loop in the agent.
    name: str = "terminate"
    arguments: type[BaseModel] = TerminateArgs
    description: str = "End the episode when finished or no progress can be made."
|
61
|
+
|
62
|
+
|
63
|
+
# --- ReAct Agent ---
|
64
|
+
class MiniGridReActAgent:
|
65
|
+
"""ReAct agent for MiniGrid environments."""
|
66
|
+
|
67
|
+
def __init__(self, llm: LM, max_turns: int = 30, verbose: bool = False):
|
68
|
+
self.llm = llm
|
69
|
+
self.max_turns = max_turns
|
70
|
+
self.verbose = verbose
|
71
|
+
self.history = []
|
72
|
+
self.debug_log = [] # Store all prompts and responses for debugging
|
73
|
+
self.system_name: str = "minigrid-react-agent" # Required for synth-sdk tracing
|
74
|
+
self.system_instance_id: str = str(uuid.uuid4()) # Required for synth-sdk tracing
|
75
|
+
|
76
|
+
# Available tools
|
77
|
+
self.tools = [MiniGridActTool(), TerminateTool()]
|
78
|
+
|
79
|
+
def _format_observation(self, obs: Dict[str, Any]) -> str:
|
80
|
+
"""Format observation for LLM."""
|
81
|
+
if "observation" in obs:
|
82
|
+
return obs["observation"]
|
83
|
+
|
84
|
+
# Fallback formatting
|
85
|
+
parts = []
|
86
|
+
if "mission" in obs:
|
87
|
+
parts.append(f"Mission: {obs['mission']}")
|
88
|
+
if "terminated" in obs:
|
89
|
+
parts.append(f"Terminated: {obs['terminated']}")
|
90
|
+
if "reward_last" in obs:
|
91
|
+
parts.append(f"Last Reward: {obs['reward_last']:.3f}")
|
92
|
+
if "total_reward" in obs:
|
93
|
+
parts.append(f"Total Reward: {obs['total_reward']:.3f}")
|
94
|
+
|
95
|
+
return "\n".join(parts)
|
96
|
+
|
97
|
+
    @trace_event_async(event_type="minigrid_react_decide")
    async def decide(self, obs: str, task_description: str, turn: int) -> Dict[str, Any]:
        """Ask the LLM for the next tool call given the current observation.

        Args:
            obs: Formatted observation text for the current turn.
            task_description: Mission instructions injected into the system prompt.
            turn: Zero-based turn index, recorded in the debug log entries.

        Returns:
            Dict with "name" (tool name) and "parameters" (JSON-decoded args).

        Raises:
            ValueError: If the LLM returns no tool call or an unrecognized
                tool-call format.
        """
        # NOTE(review): the prompt text below is behavior — the LLM's action
        # choices depend on its exact wording; do not edit casually.
        system_message = f"""You are playing a MiniGrid environment. {task_description}

CRITICAL UNDERSTANDING OF THE GRID:

1. HOW TO READ THE GRID:
- The grid shows a top-down view of a small world
- Your position is shown by an arrow: → ↓ ← ↑
- The arrow shows both WHERE you are and WHICH DIRECTION you're facing

2. GRID SYMBOLS:
- → ↓ ← ↑ = YOU (the arrow points in the direction you're facing)
- # = wall (CANNOT move through these)
- . = empty space (CAN move through these)
- G = goal (your target - GET HERE to win!)
- L = lava (AVOID - stepping on this ends the game)
- K = key, D = door, B = ball (for special levels)
- ? = edge of the grid (CANNOT move here - it's the boundary)

3. HOW MOVEMENT WORKS:
- 'forward' = move ONE space in the direction your arrow is pointing
- 'left' = turn 90 degrees left (changes arrow direction, doesn't move you)
- 'right' = turn 90 degrees right (changes arrow direction, doesn't move you)
- You CANNOT move through walls (#) or boundaries (?)

4. DEBUG MESSAGES:
- "Forward blocked by wall" = you tried to move into a wall
- "Forward blocked by boundary" = you tried to move outside the grid
- "Moved forward" = you successfully moved

5. IMPORTANT - LIMITED VISIBILITY:
- You have LIMITED VISION and can only see a small area around you
- The goal (G) might NOT be visible initially - you need to EXPLORE
- The ? symbols show areas beyond your current view
- You must move around the maze to discover new areas

6. EXPLORATION STRATEGY:
- If you DON'T see the goal (G), you must EXPLORE the maze
- Move systematically through empty spaces (.) to reveal new areas
- Try to explore unexplored paths rather than revisiting the same spots
- Keep track of where you've been to avoid going in circles
- When you discover the goal (G), then plan a path to reach it

7. LEARN FROM PAST ACTIONS:
- If an action was blocked, DON'T repeat it immediately
- If you keep getting blocked moving forward, try turning left or right
- If you're stuck in a pattern, break it by trying a different approach"""

        # Pull any engine debug lines out of the observation and repeat them
        # in a highlighted section so the LLM cannot miss a blocked move.
        debug_info = ""
        if "Debug:" in obs:
            debug_lines = [
                line
                for line in obs.split("\n")
                if "Debug:" in line or "Last action result:" in line
            ]
            if debug_lines:
                debug_info = (
                    "\n\n🚨 IMPORTANT DEBUG INFORMATION:\n"
                    + "\n".join(f"• {line}" for line in debug_lines)
                    + "\n"
                )

        # Summarize the last three recorded actions so the LLM can avoid
        # repeating failed moves. self.history is appended to elsewhere.
        action_history = ""
        if len(self.history) > 0:
            action_history = "\n\nRECENT HISTORY (Last 3 Actions):\n"
            for i, h in enumerate(self.history[-3:], 1):
                action_history += f"{i}. {h}\n"
            action_history += "\nBased on this history, avoid repeating failed actions and learn from what worked!\n"

        user_content = f"Current state:\n{obs}{debug_info}{action_history}\nCRITICAL: Check the debug information above! If blocked by wall, you MUST turn or try a different action.\n\nWhat action should I take?"

        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_content},
        ]

        # Record the outgoing prompt (including live tool objects) for
        # post-hoc debugging of the episode.
        prompt_entry = {
            "turn": turn,
            "type": "prompt",
            "messages": messages,
            "tools": self.tools,
            "timestamp": datetime.now().isoformat(),
        }
        self.debug_log.append(prompt_entry)

        response = await self.llm.respond_async(
            messages=messages,
            tools=self.tools,
        )

        # Record the raw response; getattr guards cover vendor objects that
        # lack tool_calls/content attributes.
        response_entry = {
            "turn": turn,
            "type": "llm_response",
            "response": str(response),
            "response_type": type(response).__name__,
            "tool_calls": getattr(response, "tool_calls", None),
            "content": getattr(response, "content", None),
            "timestamp": datetime.now().isoformat(),
        }
        self.debug_log.append(response_entry)

        # Debug: Print response type
        if self.verbose:
            print(f"DEBUG: LLM response type: {type(response)}")
            print(f"DEBUG: LLM response full: {response}")
            if hasattr(response, "tool_calls"):
                print(f"DEBUG: Tool calls: {response.tool_calls}")
                if response.tool_calls:
                    print(f"DEBUG: First tool call: {response.tool_calls[0]}")
                    print(f"DEBUG: First tool call type: {type(response.tool_calls[0])}")
            if hasattr(response, "content"):
                print(f"DEBUG: Response content: {response.content}")

        # Parse tool calls - fail fast, no defensive programming.
        # Only the FIRST tool call is honored; extras are ignored.
        if hasattr(response, "tool_calls") and response.tool_calls:
            tool_call = response.tool_calls[0]
            # Handle different response formats: some vendors return plain
            # dicts, others return objects with a .function attribute.
            if isinstance(tool_call, dict):
                # Dict format from LLM
                func = tool_call["function"]
                action = {
                    "name": func["name"],
                    "parameters": json.loads(func["arguments"]),
                }
            elif hasattr(tool_call, "function"):
                # Object format
                action = {
                    "name": tool_call.function.name,
                    "parameters": json.loads(tool_call.function.arguments),
                }
            else:
                # Unexpected format - fail fast
                raise ValueError(f"Unexpected tool_call format: {tool_call}")
        else:
            # No tool call - fail fast
            raise ValueError("No tool call returned from LLM")

        # Log the parsed action alongside the prompt/response entries.
        action_entry = {
            "turn": turn,
            "type": "parsed_action",
            "action": action,
            "timestamp": datetime.now().isoformat(),
        }
        self.debug_log.append(action_entry)

        return action
|
250
|
+
|
251
|
+
def _extract_position_line(self, obs: Dict[str, Any]) -> str:
    """Return the 'Agent Position' line from obs['observation'], or '' if absent.

    Centralizes the line scan that was previously duplicated in the history
    tracking and the verbose printout inside run_episode.
    """
    for line in obs.get("observation", "").split("\n"):
        if "Agent Position" in line:
            return line
    return ""

@trace_event_async(event_type="minigrid_react_episode")
async def run_episode(self, env: MiniGridEnvironment) -> Dict[str, Any]:
    """Run one episode in the environment.

    Drives the ReAct loop: observe, decide, act, and append a structured
    entry to ``self.debug_log`` for every step. The loop ends when the agent
    emits a ``terminate`` tool call, the environment reports termination, or
    ``self.max_turns`` turns have been used.

    Args:
        env: The MiniGrid environment to run (constructed, not yet initialized).

    Returns:
        Summary dict with ``success``, ``turns``, ``total_reward``,
        ``final_position``, ``total_steps``, and the full ``debug_log``.
    """
    # Initialize
    obs = await env.initialize()
    task_description = env.task_instance.impetus.instructions

    if self.verbose:
        print(f"\nTask: {task_description}")
        print(f"Initial observation:\n{self._format_observation(obs)}\n")

    success = False
    total_reward = 0.0
    last_reward = 0.0
    # Pre-bind `turn` so the summary below is safe even when max_turns == 0
    # (previously a NameError); -1 makes the reported turn count 0 in that case.
    turn = -1

    for turn in range(self.max_turns):
        # Format and log the observation
        formatted_obs = self._format_observation(obs)
        self.debug_log.append(
            {
                "turn": turn,
                "type": "observation",
                "raw_obs": obs,
                "formatted_obs": formatted_obs,
                "timestamp": datetime.now().isoformat(),
            }
        )

        # Get agent decision
        action = await self.decide(formatted_obs, task_description, turn)

        if self.verbose:
            print(f"\nTurn {turn + 1}:")
            print(f"Action: {action['name']}")
            if "parameters" in action:
                print(f"Parameters: {action['parameters']}")

        # Agent-initiated termination (fail fast if 'reason' is missing)
        if action["name"] == "terminate":
            if self.verbose:
                print(f"Agent terminated: {action['parameters']['reason']}")
            break

        # Execute action
        tool_call = {"tool": action["name"], "args": action["parameters"]}

        # Log the tool call
        self.debug_log.append(
            {
                "turn": turn,
                "type": "tool_call",
                "tool_call": tool_call,
                "timestamp": datetime.now().isoformat(),
            }
        )

        if self.verbose:
            print(f"DEBUG: Sending tool_call: {tool_call}")

        obs = await env.step(tool_call)

        # Log the environment response
        self.debug_log.append(
            {
                "turn": turn,
                "type": "env_response",
                "response": obs,
                "timestamp": datetime.now().isoformat(),
            }
        )

        if self.verbose:
            print(f"DEBUG: Environment response keys: {list(obs.keys())}")
            if "error" in obs:
                print(f"DEBUG: ERROR: {obs['error']}")

        # Track history with result (fail fast on missing keys, by design)
        action_taken = action["parameters"]["action"]
        action_reasoning = action["parameters"]["reasoning"]
        action_result = obs["last_action_result"]

        position_line = self._extract_position_line(obs)
        history_entry = (
            f"Action: {action_taken} | Reasoning: {action_reasoning} | Result: {action_result}"
        )
        if position_line:
            history_entry += f" -> {position_line}"
        self.history.append(history_entry)

        # Update metrics
        total_reward = obs["total_reward"]
        last_reward = obs["reward_last"]

        if self.verbose:
            print(f"Reward: {last_reward:.3f} (Total: {total_reward:.3f})")
            # Just print position line for brevity
            if position_line:
                print(position_line)

        # Check if the environment terminated the episode
        if obs["terminated"]:
            success = obs["success"] or "goal" in str(obs).lower()
            if self.verbose:
                print(f"\nEpisode ended! Success: {success}, Final Reward: {total_reward:.3f}")
            break

    # Get final metrics
    final_obs = await env.terminate()

    # Log final episode summary
    episode_summary = {
        "type": "episode_summary",
        "success": success,
        "turns": turn + 1,
        "total_reward": total_reward,
        "final_position": final_obs["final_position"],
        "total_steps": final_obs["total_steps"],
        "debug_log_entries": len(self.debug_log),
        "timestamp": datetime.now().isoformat(),
    }
    self.debug_log.append(episode_summary)

    return {
        "success": success,
        "turns": turn + 1,
        "total_reward": total_reward,
        "final_position": final_obs["final_position"],
        "total_steps": final_obs["total_steps"],
        "debug_log": self.debug_log,  # Include full debug log
    }
|
390
|
+
|
391
|
+
|
392
|
+
# --- Evaluation Function ---
|
393
|
+
@trace_event_async(event_type="eval_minigrid_react")
async def eval_minigrid_react(
    model_name: str = "gpt-4-mini",
    num_tasks: int = 5,
    difficulty: str = "easy",
    verbose: bool = True,
) -> Dict[str, Any]:
    """Evaluate ReAct agent on MiniGrid tasks.

    Generates a task set, runs one episode per task, writes a per-task debug
    log plus an overall summary to a timestamped directory, and returns
    aggregate statistics.

    Args:
        model_name: LLM model identifier passed to ``LM``.
        num_tasks: Number of tasks to run at the given difficulty.
        difficulty: Difficulty bucket used when generating the task set.
        verbose: Print per-task progress and a final summary.

    Returns:
        Summary dict with ``success_rate``, ``avg_reward``, ``avg_steps``,
        the per-task ``results``, and the ``debug_dir`` path.
    """
    # Generate task set
    taskset = await create_minigrid_taskset(
        num_tasks_per_difficulty={difficulty: num_tasks}, seed=42
    )

    # Initialize LLM and agent
    llm = LM(model_name=model_name, formatting_model_name=model_name, temperature=0.7)
    agent = MiniGridReActAgent(llm, max_turns=15, verbose=verbose)  # Reduced max turns

    # Create debug logs directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    debug_dir = f"minigrid_debug_logs_{timestamp}"
    os.makedirs(debug_dir, exist_ok=True)

    # Sanitize the model name once for use in filenames. Also strip path
    # separators so names like "org/model" cannot escape debug_dir or point
    # at nonexistent subdirectories (dot-only names are unchanged vs. before).
    safe_model = model_name.replace(".", "_").replace("/", "_").replace(os.sep, "_")

    # Run evaluation
    results = []
    all_debug_logs = []

    for i, task in enumerate(taskset.instances[:num_tasks]):
        if verbose:
            print(f"\n{'=' * 60}")
            print(f"Task {i + 1}/{num_tasks}: {task.metadata.env_name}")
            print(f"{'=' * 60}")

        # Create environment and run one episode
        env = MiniGridEnvironment(task)
        result = await agent.run_episode(env)
        result["task_id"] = str(task.id)
        result["env_name"] = task.metadata.env_name
        result["difficulty"] = task.metadata.difficulty

        # Save debug log for this task (popped from the result to avoid
        # duplicating it inside the summary JSON below)
        debug_log = result.pop("debug_log", [])
        debug_log_file = os.path.join(debug_dir, f"task_{i + 1}_{safe_model}_debug.json")
        with open(debug_log_file, "w") as f:
            json.dump(
                {
                    "task_info": {
                        "task_id": result["task_id"],
                        "env_name": result["env_name"],
                        "difficulty": result["difficulty"],
                        "model": model_name,
                    },
                    "result": result,
                    "debug_log": debug_log,
                },
                f,
                indent=2,
                default=str,  # best-effort: stringify non-JSON types in logs
            )

        all_debug_logs.append(debug_log)
        results.append(result)

        if verbose:
            print(f"\nResult: {result}")
            print(f"Debug log saved to: {debug_log_file}")

    # Save summary debug info
    summary_debug_file = os.path.join(debug_dir, f"summary_{safe_model}.json")
    with open(summary_debug_file, "w") as f:
        json.dump(
            {
                "model": model_name,
                "timestamp": timestamp,
                "all_debug_logs": all_debug_logs,
            },
            f,
            indent=2,
            default=str,
        )

    # Compute statistics (guarded against an empty result list)
    successes = [r["success"] for r in results]
    success_rate = sum(successes) / len(successes) if successes else 0
    avg_reward = sum(r["total_reward"] for r in results) / len(results) if results else 0
    avg_steps = sum(r["total_steps"] for r in results) / len(results) if results else 0

    summary = {
        "model": model_name,
        "num_tasks": len(results),
        "difficulty": difficulty,
        "success_rate": success_rate,
        "avg_reward": avg_reward,
        "avg_steps": avg_steps,
        "results": results,
        "debug_dir": debug_dir,
    }

    if verbose:
        print(f"\n{'=' * 60}")
        print("SUMMARY")
        print(f"{'=' * 60}")
        print(f"Success Rate: {success_rate:.1%}")
        print(f"Average Reward: {avg_reward:.3f}")
        print(f"Average Steps: {avg_steps:.1f}")

    return summary
|
503
|
+
|
504
|
+
|
505
|
+
# --- Main ---
|
506
|
+
async def main():
    """Run the demo: evaluate each candidate model and print a comparison."""
    divider = "=" * 60
    print("Testing MiniGrid ReAct Agent")
    print(divider)

    # Models to test
    candidate_models = ["gpt-4.1-nano", "gpt-4.1-mini"]
    summaries_by_model = {}

    for model_id in candidate_models:
        print(f"\n\n{divider}")
        print(f"Testing model: {model_id}")
        print(divider)

        # Run evaluation
        model_summary = await eval_minigrid_react(
            model_name=model_id,
            num_tasks=5,  # Run 5 tasks per model
            difficulty="easy",
            verbose=True,
        )
        summaries_by_model[model_id] = model_summary

        # Print model summary
        print(f"\n\nSummary for {model_id}:")
        print(f"Success Rate: {model_summary['success_rate']:.1%}")
        print(f"Average Reward: {model_summary['avg_reward']:.3f}")
        print(f"Average Steps: {model_summary['avg_steps']:.1f}")

    # Side-by-side comparison table across models
    print(f"\n\n{divider}")
    print("COMPARISON SUMMARY")
    print(divider)
    print(f"{'Model':<20} {'Success Rate':<15} {'Avg Reward':<15} {'Avg Steps':<10}")
    print("-" * 60)
    for model_id, model_summary in summaries_by_model.items():
        row = (
            f"{model_id:<20} {model_summary['success_rate']:.1%}{'':10} "
            f"{model_summary['avg_reward']:.3f}{'':10} "
            f"{model_summary['avg_steps']:.1f}"
        )
        print(row)

    # Per-task breakdown for each model
    print("\n\nDetailed Results:")
    for model_id, model_summary in summaries_by_model.items():
        print(f"\n{model_id}:")
        for task_index, task_result in enumerate(model_summary["results"]):
            print(
                f"  Task {task_index + 1}: Success={task_result['success']}, "
                f"Reward={task_result['total_reward']:.3f}, "
                f"Steps={task_result['total_steps']}"
            )
|
559
|
+
|
560
|
+
|
561
|
+
# Script entry point: run the multi-model evaluation demo.
if __name__ == "__main__":
    asyncio.run(main())
|