synth-ai 0.2.0__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +35 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +1 -1
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.2.0.dist-info/METADATA +0 -36
- synth_ai-0.2.0.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info/licenses}/LICENSE +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
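
The rename entries above move the former synth_ai.zyk.lms tree to synth_ai.lm, while synth_ai/zyk/__init__.py stays in the wheel (+28 -2), so the legacy import path still ships in 0.2.1.dev0. Below is a minimal sketch of the two import paths after the move; the re-exports are assumed from the +0 -0 file moves and from the script in the hunk that follows, not verified against the built wheel, and the model name is a placeholder.

    # Hypothetical usage against synth-ai 0.2.1.dev0; module paths follow the file listing above.
    from synth_ai.zyk import LM                   # legacy package, still present in this release
    from synth_ai.lm.tools.base import BaseTool   # formerly synth_ai/zyk/lms/tools/base.py (moved +0 -0)

    # Constructor call mirrors the one used in the test script in the hunk below.
    llm = LM(model_name="gpt-4o-mini", formatting_model_name="gpt-4o-mini", temperature=0.0)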
@@ -0,0 +1,1110 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Test script to run ReAct agents against Crafter environment on synth service (port 8901)
|
4
|
+
Tests on multiple easy Crafter instances with enhanced debugging
|
5
|
+
"""
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import json
|
9
|
+
import uuid
|
10
|
+
import math
|
11
|
+
import argparse
|
12
|
+
import toml
|
13
|
+
import logging
|
14
|
+
from datetime import datetime
|
15
|
+
from typing import Dict, Any, Optional, List, Set
|
16
|
+
from pydantic import BaseModel, Field
|
17
|
+
from httpx import AsyncClient
|
18
|
+
import sys
|
19
|
+
import os
|
20
|
+
from pathlib import Path
|
21
|
+
from tqdm.asyncio import tqdm_asyncio
|
22
|
+
|
23
|
+
# Add the src directory to the path
|
24
|
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "src"))
|
25
|
+
|
26
|
+
from synth_ai.zyk import LM
|
27
|
+
from synth_ai.zyk.lms.tools.base import BaseTool
|
28
|
+
import numpy as np
|
29
|
+
|
30
|
+
|
31
|
+
# --- Configuration Class ---
|
32
|
+
class CrafterConfig:
|
33
|
+
"""Configuration for Crafter evaluation."""
|
34
|
+
|
35
|
+
def __init__(self, config_path: Optional[str] = None):
|
36
|
+
# Default values
|
37
|
+
self.model_name: Optional[str] = None # Must be provided via config or CLI
|
38
|
+
self.num_instances = 3
|
39
|
+
self.max_turns = 20
|
40
|
+
self.difficulty = "easy"
|
41
|
+
self.service_base_url = "http://localhost:8901"
|
42
|
+
self.service_timeout = 30.0
|
43
|
+
self.seed = 42
|
44
|
+
self.save_traces = True
|
45
|
+
self.save_detailed_results = True
|
46
|
+
self.verbose = False
|
47
|
+
|
48
|
+
# Custom OpenAI endpoint support
|
49
|
+
self.custom_openai_base_url = None # e.g., "https://lora-inference-service-xyz.modal.run"
|
50
|
+
self.custom_openai_api_key = "dummy" # Default dummy key for custom endpoints
|
51
|
+
|
52
|
+
# Load from TOML if provided
|
53
|
+
if config_path and os.path.exists(config_path):
|
54
|
+
self.load_from_toml(config_path)
|
55
|
+
|
56
|
+
# Fail fast if no model name provided
|
57
|
+
# Configure custom OpenAI endpoint if specified
|
58
|
+
self._configure_custom_openai()
|
59
|
+
|
60
|
+
def load_from_toml(self, config_path: str):
|
61
|
+
"""Load configuration from TOML file."""
|
62
|
+
config = toml.load(config_path)
|
63
|
+
|
64
|
+
# Extract eval settings
|
65
|
+
eval_config = config.get("eval", {})
|
66
|
+
self.model_name = eval_config.get("model_name", self.model_name)
|
67
|
+
self.num_instances = eval_config.get("episodes", self.num_instances)
|
68
|
+
self.max_turns = eval_config.get("max_steps", self.max_turns)
|
69
|
+
self.difficulty = eval_config.get("difficulty", self.difficulty)
|
70
|
+
self.seed = eval_config.get("seed", self.seed)
|
71
|
+
|
72
|
+
# Extract service settings
|
73
|
+
service_config = config.get("service", {})
|
74
|
+
self.service_base_url = service_config.get("base_url", self.service_base_url)
|
75
|
+
self.service_timeout = service_config.get("timeout", self.service_timeout)
|
76
|
+
|
77
|
+
# Extract output settings
|
78
|
+
output_config = config.get("output", {})
|
79
|
+
self.save_traces = output_config.get("save_traces", self.save_traces)
|
80
|
+
self.save_detailed_results = output_config.get(
|
81
|
+
"save_detailed_results", self.save_detailed_results
|
82
|
+
)
|
83
|
+
|
84
|
+
# Extract custom OpenAI endpoint settings
|
85
|
+
openai_config = config.get("openai", {})
|
86
|
+
self.custom_openai_base_url = openai_config.get("base_url", self.custom_openai_base_url)
|
87
|
+
self.custom_openai_api_key = openai_config.get("api_key", self.custom_openai_api_key)
|
88
|
+
|
89
|
+
def _configure_custom_openai(self):
|
90
|
+
"""Configure environment variables for custom OpenAI endpoint if specified."""
|
91
|
+
if self.custom_openai_base_url:
|
92
|
+
# Ensure the base URL ends with /v1 for OpenAI compatibility
|
93
|
+
base_url = self.custom_openai_base_url.rstrip("/")
|
94
|
+
if not base_url.endswith("/v1"):
|
95
|
+
base_url += "/v1"
|
96
|
+
|
97
|
+
# Set environment variables for OpenAI SDK
|
98
|
+
os.environ["OPENAI_BASE_URL"] = base_url
|
99
|
+
os.environ["OPENAI_API_KEY"] = self.custom_openai_api_key
|
100
|
+
|
101
|
+
print(f"š§ Configured custom OpenAI endpoint: {base_url}")
|
102
|
+
print(f" API Key: {self.custom_openai_api_key}")
|
103
|
+
|
104
|
+
# Auto-detect if this looks like a fine-tuned model and add ft: regex support
|
105
|
+
if self.model_name and (
|
106
|
+
self.model_name.startswith("ft:") or "lora" in self.model_name.lower()
|
107
|
+
):
|
108
|
+
self._add_ft_regex_support()
|
109
|
+
|
110
|
+
def _add_ft_regex_support(self):
|
111
|
+
"""Add ft: regex pattern to OpenAI naming regexes if not already present."""
|
112
|
+
try:
|
113
|
+
import re
|
114
|
+
from synth_ai.zyk.lms.core import vendor_clients
|
115
|
+
|
116
|
+
# Check if ft: pattern already exists
|
117
|
+
ft_pattern = re.compile(r"^ft:.*$")
|
118
|
+
if not any(
|
119
|
+
pattern.pattern == ft_pattern.pattern
|
120
|
+
for pattern in vendor_clients.openai_naming_regexes
|
121
|
+
):
|
122
|
+
# Add ft: pattern at the beginning to catch all fine-tuned models
|
123
|
+
vendor_clients.openai_naming_regexes.insert(0, ft_pattern)
|
124
|
+
print(f"ā
Added ft:* regex pattern for fine-tuned model support")
|
125
|
+
except Exception as e:
|
126
|
+
print(f"ā ļø Warning: Could not add ft: regex pattern: {e}")
|
127
|
+
|
128
|
+
def set_custom_endpoint(self, base_url: str, api_key: str = "dummy"):
|
129
|
+
"""Programmatically set custom OpenAI endpoint."""
|
130
|
+
self.custom_openai_base_url = base_url
|
131
|
+
self.custom_openai_api_key = api_key
|
132
|
+
self._configure_custom_openai()
|
133
|
+
|
134
|
+
|
135
|
+
# --- Global Config ---
|
136
|
+
config = CrafterConfig()
|
137
|
+
|
138
|
+
|
139
|
+
# --- Helper to build crafter semantic mapping ---
|
140
|
+
def get_crafter_semantic_mapping():
|
141
|
+
"""Build the crafter semantic ID to item name mapping."""
|
142
|
+
try:
|
143
|
+
import crafter
|
144
|
+
import itertools
|
145
|
+
|
146
|
+
# Create a dummy env to get ID mappings
|
147
|
+
dummyenv = None
|
148
|
+
try:
|
149
|
+
dummyenv = crafter.Env()
|
150
|
+
max_id = (
|
151
|
+
max(
|
152
|
+
max(dummyenv._world._mat_ids.values()),
|
153
|
+
max(dummyenv._sem_view._obj_ids.values()),
|
154
|
+
)
|
155
|
+
+ 1
|
156
|
+
)
|
157
|
+
id_to_item = ["void"] * max_id
|
158
|
+
for name, ind in itertools.chain(
|
159
|
+
dummyenv._world._mat_ids.items(), dummyenv._sem_view._obj_ids.items()
|
160
|
+
):
|
161
|
+
if name is None:
|
162
|
+
clean = "none"
|
163
|
+
elif hasattr(name, "__name__"):
|
164
|
+
clean = name.__name__
|
165
|
+
else:
|
166
|
+
clean = str(name)
|
167
|
+
id_to_item[ind] = clean.lower()
|
168
|
+
player_idx = id_to_item.index("player")
|
169
|
+
return id_to_item, player_idx
|
170
|
+
finally:
|
171
|
+
if dummyenv:
|
172
|
+
try:
|
173
|
+
dummyenv.close()
|
174
|
+
except Exception:
|
175
|
+
pass
|
176
|
+
del dummyenv
|
177
|
+
except ImportError:
|
178
|
+
# Fallback if crafter is not available
|
179
|
+
return None, None
|
180
|
+
|
181
|
+
|
182
|
+
def format_semantic_map_view(obs_data: Dict[str, Any], view_size: int = 7) -> str:
|
183
|
+
"""Format a semantic map view around the player (ASCII)."""
|
184
|
+
try:
|
185
|
+
# Get mapping list
|
186
|
+
id_to_item, _ = get_crafter_semantic_mapping()
|
187
|
+
if id_to_item is None:
|
188
|
+
return "Map view unavailable (crafter not installed)"
|
189
|
+
|
190
|
+
semantic_map = obs_data.get("semantic_map")
|
191
|
+
player_position = obs_data.get("player_position", [0, 0])
|
192
|
+
|
193
|
+
if semantic_map is None:
|
194
|
+
return "Map view unavailable (no semantic map data)"
|
195
|
+
|
196
|
+
# Ensure numpy array with 2 dimensions
|
197
|
+
sem_arr = np.asarray(semantic_map)
|
198
|
+
if sem_arr.ndim == 1:
|
199
|
+
# Probably flattened; try to infer square size
|
200
|
+
size = int(np.sqrt(sem_arr.size))
|
201
|
+
sem_arr = sem_arr.reshape(size, size)
|
202
|
+
elif sem_arr.ndim != 2:
|
203
|
+
return "Map view unavailable (invalid map dimensionality)"
|
204
|
+
|
205
|
+
px, py = map(int, player_position)
|
206
|
+
half = view_size // 2
|
207
|
+
|
208
|
+
rows = []
|
209
|
+
visible = set()
|
210
|
+
for dy in range(-half, half + 1):
|
211
|
+
row_tokens = []
|
212
|
+
for dx in range(-half, half + 1):
|
213
|
+
x, y = px + dx, py + dy
|
214
|
+
if 0 <= x < sem_arr.shape[0] and 0 <= y < sem_arr.shape[1]:
|
215
|
+
if dx == 0 and dy == 0:
|
216
|
+
token = "player"
|
217
|
+
else:
|
218
|
+
idx = int(sem_arr[x, y])
|
219
|
+
token = id_to_item[idx] if idx < len(id_to_item) else "?"
|
220
|
+
else:
|
221
|
+
token = "void"
|
222
|
+
row_tokens.append(token)
|
223
|
+
if token not in {"void", "player"}:
|
224
|
+
visible.add(token)
|
225
|
+
rows.append(" ".join(row_tokens))
|
226
|
+
|
227
|
+
map_view = f"\nLocal Map View ({view_size}x{view_size}):\n" + "\n".join(rows)
|
228
|
+
if visible:
|
229
|
+
map_view += "\nVisible items: " + ", ".join(sorted(visible))
|
230
|
+
else:
|
231
|
+
map_view += "\nNo special items visible (mostly grass/empty)"
|
232
|
+
return map_view
|
233
|
+
except Exception as e:
|
234
|
+
return f"Map view error: {e}"
|
235
|
+
|
236
|
+
|
237
|
+
# --- Shaped Reward Configuration ---
|
238
|
+
# K-values for shaped reward calculation: reward = sum(K * log(count)) for each achievement
|
239
|
+
ACHIEVEMENT_K_VALUES = {
|
240
|
+
"collect_coal": 3.0,
|
241
|
+
"collect_diamond": 100.0,
|
242
|
+
"collect_drink": 0.1,
|
243
|
+
"collect_iron": 10.0,
|
244
|
+
"collect_sapling": 0.1,
|
245
|
+
"collect_stone": 1.0,
|
246
|
+
"collect_wood": 1.0,
|
247
|
+
"defeat_skeleton": 1.0,
|
248
|
+
"defeat_zombie": 1.0,
|
249
|
+
"eat_cow": 1.0,
|
250
|
+
"eat_plant": 0.1,
|
251
|
+
"make_iron_pickaxe": 30.0,
|
252
|
+
"make_iron_sword": 30.0,
|
253
|
+
"make_stone_pickaxe": 10.0,
|
254
|
+
"make_stone_sword": 10.0,
|
255
|
+
"make_wood_pickaxe": 3.0,
|
256
|
+
"make_wood_sword": 3.0,
|
257
|
+
"place_furnace": 10.0,
|
258
|
+
"place_plant": 0.1,
|
259
|
+
"place_stone": 1.0,
|
260
|
+
"place_table": 3.0,
|
261
|
+
"wake_up": 0.1,
|
262
|
+
}
|
263
|
+
|
264
|
+
|
265
|
+
# --- Tool Definitions ---
|
266
|
+
class CrafterActionArgs(BaseModel):
|
267
|
+
"""Arguments for crafter actions."""
|
268
|
+
|
269
|
+
actions: List[str] = Field(
|
270
|
+
description="List of 1-5 action names to execute in sequence (e.g., ['move_up', 'do', 'mine_down'])"
|
271
|
+
)
|
272
|
+
reasoning: str = Field(description="Brief explanation of why these actions were chosen")
|
273
|
+
|
274
|
+
|
275
|
+
class TerminateArgs(BaseModel):
|
276
|
+
"""Arguments for termination."""
|
277
|
+
|
278
|
+
reason: str = Field(description="Reason for termination")
|
279
|
+
|
280
|
+
|
281
|
+
class CrafterActionTool(BaseTool):
|
282
|
+
"""Tool for performing actions in the Crafter environment."""
|
283
|
+
|
284
|
+
name: str = "interact"
|
285
|
+
arguments: type[BaseModel] = CrafterActionArgs
|
286
|
+
description: str = "Perform 1-5 actions in sequence in the Crafter environment."
|
287
|
+
|
288
|
+
|
289
|
+
class TerminateTool(BaseTool):
|
290
|
+
"""Tool to terminate the episode."""
|
291
|
+
|
292
|
+
name: str = "terminate"
|
293
|
+
arguments: type[BaseModel] = TerminateArgs
|
294
|
+
description: str = "End the episode when finished or no progress can be made."
|
295
|
+
|
296
|
+
|
297
|
+
# --- Shaped Reward Helper ---
|
298
|
+
def calculate_shaped_reward(achievement_counts: Dict[str, int]) -> Dict[str, Any]:
|
299
|
+
"""Calculate shaped reward using K * log(count) for each achievement."""
|
300
|
+
total_reward = 0.0
|
301
|
+
reward_breakdown = {}
|
302
|
+
|
303
|
+
for achievement, count in achievement_counts.items():
|
304
|
+
if count > 0 and achievement in ACHIEVEMENT_K_VALUES:
|
305
|
+
k_value = ACHIEVEMENT_K_VALUES[achievement]
|
306
|
+
# Use log(count + 1) to handle count=0 case gracefully
|
307
|
+
reward_contribution = k_value * math.log(count + 1)
|
308
|
+
total_reward += reward_contribution
|
309
|
+
reward_breakdown[achievement] = {
|
310
|
+
"count": count,
|
311
|
+
"k_value": k_value,
|
312
|
+
"contribution": reward_contribution,
|
313
|
+
}
|
314
|
+
|
315
|
+
return {"total_shaped_reward": total_reward, "breakdown": reward_breakdown}
|
316
|
+
|
317
|
+
|
318
|
+
# --- Base ReAct Agent ---
|
319
|
+
class BaseReActAgent:
|
320
|
+
"""Base ReAct agent for environment interaction."""
|
321
|
+
|
322
|
+
def __init__(self, llm: LM, max_turns: int = 20, verbose: bool = False):
|
323
|
+
self.llm = llm
|
324
|
+
self.max_turns = max_turns
|
325
|
+
self.verbose = verbose
|
326
|
+
self.history = []
|
327
|
+
self.system_name = "base-react-agent"
|
328
|
+
|
329
|
+
# Define tools in OpenAI format
|
330
|
+
self.tools = [
|
331
|
+
CrafterActionTool(),
|
332
|
+
TerminateTool(),
|
333
|
+
]
|
334
|
+
|
335
|
+
async def decide(self, obs: str, system_message: str, turn: int) -> Dict[str, Any]:
|
336
|
+
"""Get agent decision based on observation."""
|
337
|
+
# Create conversation context
|
338
|
+
context = f"Turn {turn + 1}/{self.max_turns}\n\n{obs}"
|
339
|
+
# Generate response using LLM
|
340
|
+
response_obj = await self.llm.respond_async(
|
341
|
+
system_message=system_message, user_message=context, tools=self.tools
|
342
|
+
)
|
343
|
+
|
344
|
+
tool_calls = response_obj.tool_calls
|
345
|
+
|
346
|
+
# Handle case where tool_calls is None or empty (graceful fallback)
|
347
|
+
if not tool_calls:
|
348
|
+
if self.verbose:
|
349
|
+
print(f"[WARNING] No tool calls returned by LLM, using default action")
|
350
|
+
return {
|
351
|
+
"name": "interact",
|
352
|
+
"parameters": {
|
353
|
+
"actions": ["do"],
|
354
|
+
"reasoning": "Default action - no tool call received",
|
355
|
+
},
|
356
|
+
}
|
357
|
+
|
358
|
+
tool_call_data = tool_calls[0]
|
359
|
+
|
360
|
+
# Handle both dict and object formats
|
361
|
+
if isinstance(tool_call_data, dict):
|
362
|
+
tool_name = tool_call_data["function"]["name"]
|
363
|
+
tool_args_str = tool_call_data["function"]["arguments"]
|
364
|
+
else:
|
365
|
+
tool_name = tool_call_data.function.name
|
366
|
+
tool_args_str = tool_call_data.function.arguments
|
367
|
+
|
368
|
+
tool_arguments = json.loads(tool_args_str)
|
369
|
+
|
370
|
+
return {"name": tool_name, "parameters": tool_arguments}
|
371
|
+
|
372
|
+
|
373
|
+
# --- Crafter ReAct Agent ---
|
374
|
+
class CrafterReActAgent(BaseReActAgent):
|
375
|
+
"""ReAct agent for Crafter environment."""
|
376
|
+
|
377
|
+
def __init__(self, llm: LM, max_turns: int = 20, verbose: bool = False):
|
378
|
+
super().__init__(llm, max_turns, verbose)
|
379
|
+
self.system_name = "crafter-react-agent"
|
380
|
+
|
381
|
+
def get_system_message(self) -> str:
|
382
|
+
return """You are CrafterAgent playing Crafter survival environment. Your goal is to unlock as many achievements as possible while staying alive.
|
383
|
+
|
384
|
+
You will see a semantic map view showing your surroundings. Use this to navigate toward resources.
|
385
|
+
|
386
|
+
Key mechanics:
|
387
|
+
⢠'do' action: collect wood from trees, stone from deposits, food from cows/plants
|
388
|
+
⢠'do' does nothing on grass/water - move to find resources first
|
389
|
+
⢠Craft progression: wood ā table ā wood_pickaxe ā stone ā stone_pickaxe ā iron tools
|
390
|
+
⢠Sleep when energy low to restore and unlock wake_up achievement
|
391
|
+
⢠Use semantic map view to navigate toward resources you can see
|
392
|
+
|
393
|
+
Available actions: move_left, move_right, move_up, move_down, do, sleep, place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword, noop
|
394
|
+
|
395
|
+
Strategy:
|
396
|
+
1. Look at the semantic map to see what's around you
|
397
|
+
2. Move toward trees to collect wood with 'do'
|
398
|
+
3. Once you have wood, place a table to enable crafting
|
399
|
+
4. Make a wood pickaxe to collect stone more efficiently
|
400
|
+
5. Progress to stone pickaxe, then iron tools
|
401
|
+
6. Eat food when health is low, sleep when energy is low
|
402
|
+
|
403
|
+
You should provide 1-5 actions in sequence for efficient gameplay. Use the semantic map view to navigate toward visible resources.
|
404
|
+
|
405
|
+
Example good action sequences:
|
406
|
+
- ['move_right', 'move_right', 'do'] (move to tree and collect wood)
|
407
|
+
- ['place_table', 'make_wood_pickaxe'] (craft progression)
|
408
|
+
- ['move_up', 'do', 'move_down', 'do'] (collect from multiple resources)
|
409
|
+
|
410
|
+
Be strategic and use the map view to find resources! Focus on unlocking achievements."""
|
411
|
+
|
412
|
+
def format_observation(self, obs: Dict[str, Any]) -> str:
|
413
|
+
"""Format observation for Crafter with rich context."""
|
414
|
+
# Extract key information from observation
|
415
|
+
health = obs.get("health", 0)
|
416
|
+
inventory = obs.get("inventory", {})
|
417
|
+
|
418
|
+
# Extract health from inventory if not in main obs
|
419
|
+
if health == 0 and "health" in inventory:
|
420
|
+
health = inventory["health"]
|
421
|
+
|
422
|
+
# Format inventory items (exclude health since we show it separately)
|
423
|
+
inventory_items = []
|
424
|
+
for item, count in inventory.items():
|
425
|
+
if count > 0 and item != "health":
|
426
|
+
inventory_items.append(f"{item}: {count}")
|
427
|
+
|
428
|
+
inventory_str = ", ".join(inventory_items) if inventory_items else "empty"
|
429
|
+
|
430
|
+
# Get achievements
|
431
|
+
achievements = obs.get("achievements") or obs.get("achievements_status", {})
|
432
|
+
unlocked_achievements = [name for name, unlocked in achievements.items() if unlocked]
|
433
|
+
achievements_str = ", ".join(unlocked_achievements) if unlocked_achievements else "none"
|
434
|
+
|
435
|
+
# Get position and other state
|
436
|
+
position = obs.get("position", [0, 0])
|
437
|
+
player_position = obs.get("player_position", position)
|
438
|
+
player_direction = obs.get("player_direction", [0, 1])
|
439
|
+
num_steps = obs.get("num_steps_taken", 0)
|
440
|
+
|
441
|
+
# Check termination status
|
442
|
+
terminated = obs.get("terminated", False)
|
443
|
+
|
444
|
+
# Get semantic map view
|
445
|
+
map_view = format_semantic_map_view(obs, view_size=5)
|
446
|
+
|
447
|
+
return (
|
448
|
+
f"Crafter Game State:\n"
|
449
|
+
f"Step: {num_steps}\n"
|
450
|
+
f"Health: {health}\n"
|
451
|
+
f"Position: {player_position}\n"
|
452
|
+
f"Direction: {player_direction}\n"
|
453
|
+
f"Inventory: {inventory_str}\n"
|
454
|
+
f"Achievements: {achievements_str}\n"
|
455
|
+
f"Terminated: {terminated}\n"
|
456
|
+
f"{map_view}\n\n"
|
457
|
+
f"Available actions: move_left, move_right, move_up, move_down, do, sleep, place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword, noop\n\n"
|
458
|
+
f"Key mechanics:\n"
|
459
|
+
f"⢠'do' action: collect wood from trees, stone from deposits, food from cows/plants\n"
|
460
|
+
f"⢠'do' does nothing on grass/water - move to find resources\n"
|
461
|
+
f"⢠Craft progression: wood ā table ā wood_pickaxe ā stone ā stone_pickaxe ā iron tools\n"
|
462
|
+
f"⢠Sleep when energy low to restore and unlock wake_up achievement\n\n"
|
463
|
+
f"Choose 1-5 actions to execute in sequence. Focus on exploring to find resources and crafting tools to unlock achievements."
|
464
|
+
)
|
465
|
+
|
466
|
+
|
467
|
+
# --- Episode Runner ---
|
468
|
+
async def run_single_episode(
|
469
|
+
client: AsyncClient, agent: CrafterReActAgent, task_instance, instance_num: int
|
470
|
+
) -> Dict[str, Any]:
|
471
|
+
"""Run a single Crafter episode and return episode metrics."""
|
472
|
+
try:
|
473
|
+
# Create environment using the task instance
|
474
|
+
create_resp = await client.post(
|
475
|
+
f"/env/CrafterClassic/initialize",
|
476
|
+
json={"task_instance": await task_instance.serialize()},
|
477
|
+
)
|
478
|
+
|
479
|
+
if create_resp.status_code != 200:
|
480
|
+
print(
|
481
|
+
f" Instance {instance_num}: Failed to create environment - {create_resp.status_code}: {create_resp.text}"
|
482
|
+
)
|
483
|
+
return {
|
484
|
+
"eval_metric": 0.0,
|
485
|
+
"rubric": {},
|
486
|
+
"total_reward": 0.0,
|
487
|
+
"num_achievements": 0,
|
488
|
+
"terminated": False,
|
489
|
+
"error": True,
|
490
|
+
}
|
491
|
+
|
492
|
+
env_id = create_resp.json()["env_id"]
|
493
|
+
|
494
|
+
# Get initial observation
|
495
|
+
obs = create_resp.json()["observation"]
|
496
|
+
formatted_obs = agent.format_observation(obs)
|
497
|
+
|
498
|
+
# DEBUG: Print initial state (minimal)
|
499
|
+
print(
|
500
|
+
f"\n Instance {instance_num}: Starting Crafter survival ({task_instance.metadata.difficulty}, {agent.max_turns} turns max)"
|
501
|
+
)
|
502
|
+
|
503
|
+
# Track episode metrics
|
504
|
+
total_reward = 0.0
|
505
|
+
final_achievements = {}
|
506
|
+
num_achievements = 0
|
507
|
+
terminated = False
|
508
|
+
rollout_length = 0
|
509
|
+
|
510
|
+
# Run episode
|
511
|
+
for turn in range(agent.max_turns):
|
512
|
+
try:
|
513
|
+
# Get agent decision
|
514
|
+
action = await agent.decide(formatted_obs, agent.get_system_message(), turn)
|
515
|
+
# print(f" ā
Agent decision received: {action}")
|
516
|
+
|
517
|
+
# # DEBUG: Print agent decision with safer access
|
518
|
+
# try:
|
519
|
+
# actions = action.get('parameters', {}).get('actions', action.get('arguments', {}).get('actions', []))
|
520
|
+
# reasoning = action.get('parameters', {}).get('reasoning', action.get('arguments', {}).get('reasoning', 'no reasoning'))
|
521
|
+
# #print(f" Turn {turn+1}: Agent chose {actions} - {reasoning}")
|
522
|
+
# except Exception as e:
|
523
|
+
# print(f" Turn {turn+1}: Agent action structure: {action}")
|
524
|
+
# print(f" Error parsing action: {e}")
|
525
|
+
|
526
|
+
# Check for termination
|
527
|
+
if action["name"] == "terminate":
|
528
|
+
reason = action.get("parameters", {}).get(
|
529
|
+
"reason", action.get("arguments", {}).get("reason", "no reason given")
|
530
|
+
)
|
531
|
+
print(f" Agent terminated: {reason}")
|
532
|
+
break
|
533
|
+
|
534
|
+
# Execute actions in environment with safer access
|
535
|
+
action_sequence = action.get("parameters", {}).get(
|
536
|
+
"actions", action.get("arguments", {}).get("actions", [])
|
537
|
+
)
|
538
|
+
if not action_sequence:
|
539
|
+
print(f" ā ļø No actions found in: {action}")
|
540
|
+
continue
|
541
|
+
|
542
|
+
# Convert action names to integers using the proper action map
|
543
|
+
# Define the proper Crafter action mapping
|
544
|
+
CRAFTER_ACTION_MAP = {
|
545
|
+
"noop": 0,
|
546
|
+
"move_left": 1,
|
547
|
+
"move_right": 2,
|
548
|
+
"move_up": 3,
|
549
|
+
"move_down": 4,
|
550
|
+
"do": 5,
|
551
|
+
"sleep": 6,
|
552
|
+
"place_stone": 7,
|
553
|
+
"place_table": 8,
|
554
|
+
"place_furnace": 9,
|
555
|
+
"place_plant": 10,
|
556
|
+
"make_wood_pickaxe": 11,
|
557
|
+
"make_stone_pickaxe": 12,
|
558
|
+
"make_iron_pickaxe": 13,
|
559
|
+
"make_wood_sword": 14,
|
560
|
+
"make_stone_sword": 15,
|
561
|
+
"make_iron_sword": 16,
|
562
|
+
}
|
563
|
+
|
564
|
+
action_ints = []
|
565
|
+
for action_name in action_sequence:
|
566
|
+
if action_name in CRAFTER_ACTION_MAP:
|
567
|
+
action_int = CRAFTER_ACTION_MAP[action_name]
|
568
|
+
else:
|
569
|
+
action_int = 0 # Default to noop
|
570
|
+
action_ints.append(action_int)
|
571
|
+
|
572
|
+
# Execute each action individually (Crafter expects single actions)
|
573
|
+
for i, action_int in enumerate(action_ints):
|
574
|
+
step_resp = await client.post(
|
575
|
+
f"/env/CrafterClassic/step",
|
576
|
+
json={
|
577
|
+
"env_id": env_id,
|
578
|
+
"request_id": str(uuid.uuid4()),
|
579
|
+
"action": {
|
580
|
+
"tool_calls": [{"tool": "interact", "args": {"action": action_int}}]
|
581
|
+
},
|
582
|
+
},
|
583
|
+
)
|
584
|
+
|
585
|
+
if step_resp.status_code != 200:
|
586
|
+
print(
|
587
|
+
f" ā Action {i + 1} failed: {step_resp.status_code}: {step_resp.text}"
|
588
|
+
)
|
589
|
+
break
|
590
|
+
|
591
|
+
# Update observation after each action
|
592
|
+
obs = step_resp.json()["observation"]
|
593
|
+
|
594
|
+
# Check final response status
|
595
|
+
if step_resp.status_code != 200:
|
596
|
+
break
|
597
|
+
|
598
|
+
# Show final state after all actions
|
599
|
+
formatted_obs = agent.format_observation(obs)
|
600
|
+
step_count = obs.get("num_steps_taken", 0)
|
601
|
+
rollout_length = step_count
|
602
|
+
position = obs.get("player_position", [0, 0])
|
603
|
+
# print(f" Turn {turn+1}: Actions completed - Step: {step_count}, Position: {position}")
|
604
|
+
|
605
|
+
# Update history with safer access
|
606
|
+
reasoning = action.get("parameters", {}).get(
|
607
|
+
"reasoning", action.get("arguments", {}).get("reasoning", "")
|
608
|
+
)
|
609
|
+
agent.history.append(f"{', '.join(action_sequence)}: {reasoning[:50]}")
|
610
|
+
|
611
|
+
# Track episode progress - Use the FINAL observation from the last action
|
612
|
+
terminated = obs.get("terminated", False)
|
613
|
+
step_reward = obs.get("reward", 0.0)
|
614
|
+
total_reward += step_reward
|
615
|
+
achievements = obs.get("achievements") or obs.get("achievements_status", {})
|
616
|
+
|
617
|
+
# ALWAYS update final_achievements with the latest observation
|
618
|
+
final_achievements = achievements
|
619
|
+
|
620
|
+
num_achievements = sum(1 for v in achievements.values() if v) if achievements else 0
|
621
|
+
|
622
|
+
if terminated:
|
623
|
+
print(
|
624
|
+
f" ā
Instance {instance_num}: Episode completed! Achievements: {num_achievements}, Total reward: {total_reward:.3f}"
|
625
|
+
)
|
626
|
+
break
|
627
|
+
|
628
|
+
except Exception as e:
|
629
|
+
print(f" ā Error in turn {turn + 1}: {e}")
|
630
|
+
import traceback
|
631
|
+
|
632
|
+
traceback.print_exc()
|
633
|
+
break
|
634
|
+
|
635
|
+
# Cleanup
|
636
|
+
await client.post(f"/env/CrafterClassic/terminate", json={"env_id": env_id})
|
637
|
+
|
638
|
+
# Calculate K-weighted achievement reward
|
639
|
+
achievement_reward = 0.0
|
640
|
+
if final_achievements:
|
641
|
+
for achievement, unlocked in final_achievements.items():
|
642
|
+
if unlocked and achievement in ACHIEVEMENT_K_VALUES:
|
643
|
+
k_value = ACHIEVEMENT_K_VALUES[achievement]
|
644
|
+
achievement_reward += k_value * math.log(2) # log(1+1) for single achievement
|
645
|
+
|
646
|
+
# Use achievement reward as the total reward
|
647
|
+
total_reward = achievement_reward
|
648
|
+
|
649
|
+
# Calculate eval metric and rubric
|
650
|
+
eval_metric = float(num_achievements) # Simple metric: number of achievements
|
651
|
+
|
652
|
+
# Create rubric with specific achievement checks
|
653
|
+
rubric = {}
|
654
|
+
if final_achievements:
|
655
|
+
rubric = {
|
656
|
+
"collect_wood": 1.0 if final_achievements.get("collect_wood", False) else 0.0,
|
657
|
+
"collect_stone": 1.0 if final_achievements.get("collect_stone", False) else 0.0,
|
658
|
+
"collect_coal": 1.0 if final_achievements.get("collect_coal", False) else 0.0,
|
659
|
+
"collect_iron": 1.0 if final_achievements.get("collect_iron", False) else 0.0,
|
660
|
+
"collect_diamond": 1.0 if final_achievements.get("collect_diamond", False) else 0.0,
|
661
|
+
"place_table": 1.0 if final_achievements.get("place_table", False) else 0.0,
|
662
|
+
"place_furnace": 1.0 if final_achievements.get("place_furnace", False) else 0.0,
|
663
|
+
"make_wood_pickaxe": 1.0
|
664
|
+
if final_achievements.get("make_wood_pickaxe", False)
|
665
|
+
else 0.0,
|
666
|
+
"make_stone_pickaxe": 1.0
|
667
|
+
if final_achievements.get("make_stone_pickaxe", False)
|
668
|
+
else 0.0,
|
669
|
+
"make_iron_pickaxe": 1.0
|
670
|
+
if final_achievements.get("make_iron_pickaxe", False)
|
671
|
+
else 0.0,
|
672
|
+
"make_wood_sword": 1.0 if final_achievements.get("make_wood_sword", False) else 0.0,
|
673
|
+
"make_stone_sword": 1.0
|
674
|
+
if final_achievements.get("make_stone_sword", False)
|
675
|
+
else 0.0,
|
676
|
+
"make_iron_sword": 1.0 if final_achievements.get("make_iron_sword", False) else 0.0,
|
677
|
+
"defeat_skeleton": 1.0 if final_achievements.get("defeat_skeleton", False) else 0.0,
|
678
|
+
"defeat_zombie": 1.0 if final_achievements.get("defeat_zombie", False) else 0.0,
|
679
|
+
"wake_up": 1.0 if final_achievements.get("wake_up", False) else 0.0,
|
680
|
+
"eat_cow": 1.0 if final_achievements.get("eat_cow", False) else 0.0,
|
681
|
+
"eat_plant": 1.0 if final_achievements.get("eat_plant", False) else 0.0,
|
682
|
+
}
|
683
|
+
else:
|
684
|
+
# Default rubric with all zeros
|
685
|
+
rubric = {
|
686
|
+
"collect_wood": 0.0,
|
687
|
+
"collect_stone": 0.0,
|
688
|
+
"collect_coal": 0.0,
|
689
|
+
"collect_iron": 0.0,
|
690
|
+
"collect_diamond": 0.0,
|
691
|
+
"place_table": 0.0,
|
692
|
+
"place_furnace": 0.0,
|
693
|
+
"make_wood_pickaxe": 0.0,
|
694
|
+
"make_stone_pickaxe": 0.0,
|
695
|
+
"make_iron_pickaxe": 0.0,
|
696
|
+
"make_wood_sword": 0.0,
|
697
|
+
"make_stone_sword": 0.0,
|
698
|
+
"make_iron_sword": 0.0,
|
699
|
+
"defeat_skeleton": 0.0,
|
700
|
+
"defeat_zombie": 0.0,
|
701
|
+
"wake_up": 0.0,
|
702
|
+
"eat_cow": 0.0,
|
703
|
+
"eat_plant": 0.0,
|
704
|
+
}
|
705
|
+
|
706
|
+
return {
|
707
|
+
"eval_metric": eval_metric,
|
708
|
+
"rubric": rubric,
|
709
|
+
"total_reward": total_reward,
|
710
|
+
"num_achievements": num_achievements,
|
711
|
+
"achievements": final_achievements,
|
712
|
+
"rollout_length": rollout_length,
|
713
|
+
"terminated": terminated,
|
714
|
+
"error": False,
|
715
|
+
}
|
716
|
+
|
717
|
+
except Exception as e:
|
718
|
+
print(f" Instance {instance_num}: Error - {e}")
|
719
|
+
import traceback
|
720
|
+
|
721
|
+
traceback.print_exc()
|
722
|
+
return {
|
723
|
+
"eval_metric": 0.0,
|
724
|
+
"rubric": {},
|
725
|
+
"total_reward": 0.0,
|
726
|
+
"num_achievements": 0,
|
727
|
+
"terminated": False,
|
728
|
+
"error": True,
|
729
|
+
}
|
730
|
+
|
731
|
+
|
732
|
+
# --- Batch Evaluation ---
|
733
|
+
async def evaluate_crafter_batch() -> Dict[str, Any]:
|
734
|
+
"""Evaluate Crafter agent on multiple easy instances."""
|
735
|
+
print(f"šÆ Evaluating Crafter on {config.num_instances} {config.difficulty} instances...")
|
736
|
+
|
737
|
+
llm = LM(model_name=config.model_name, formatting_model_name=config.model_name, temperature=0.0)
|
738
|
+
|
739
|
+
# Get easy task instances using the taskset system
|
740
|
+
from synth_ai.environments.examples.crafter_classic.taskset import (
|
741
|
+
CrafterTaskInstance,
|
742
|
+
CrafterTaskInstanceMetadata,
|
743
|
+
)
|
744
|
+
from synth_ai.environments.tasks.core import Impetus, Intent
|
745
|
+
|
746
|
+
easy_task_instances = []
|
747
|
+
for seed in range(config.num_instances):
|
748
|
+
try:
|
749
|
+
metadata = CrafterTaskInstanceMetadata(
|
750
|
+
difficulty=config.difficulty,
|
751
|
+
seed=seed,
|
752
|
+
num_trees_radius=5, # Good for easy difficulty
|
753
|
+
num_cows_radius=2,
|
754
|
+
num_hostiles_radius=0, # No hostiles for easy
|
755
|
+
)
|
756
|
+
task_instance = CrafterTaskInstance(
|
757
|
+
id=uuid.uuid4(),
|
758
|
+
impetus=Impetus(
|
759
|
+
instructions=f"Survive and unlock achievements in an {config.difficulty} environment."
|
760
|
+
),
|
761
|
+
intent=Intent(rubric={}, gold_trajectories=None, gold_state_diff={}),
|
762
|
+
metadata=metadata,
|
763
|
+
is_reproducible=True,
|
764
|
+
initial_engine_snapshot=None,
|
765
|
+
)
|
766
|
+
easy_task_instances.append(task_instance)
|
767
|
+
except Exception as e:
|
768
|
+
print(f" ā ļø Failed to create task instance for seed {seed}: {e}")
|
769
|
+
continue
|
770
|
+
|
771
|
+
print(f" š Generated {len(easy_task_instances)} {config.difficulty} task instances")
|
772
|
+
|
773
|
+
async with AsyncClient(
|
774
|
+
base_url=config.service_base_url, timeout=config.service_timeout
|
775
|
+
) as client:
|
776
|
+
# Run trajectories in parallel batches of 4
|
777
|
+
batch_size = 4
|
778
|
+
all_results = []
|
779
|
+
|
780
|
+
for batch_start in range(0, len(easy_task_instances), batch_size):
|
781
|
+
batch_end = min(batch_start + batch_size, len(easy_task_instances))
|
782
|
+
batch_instances = easy_task_instances[batch_start:batch_end]
|
783
|
+
|
784
|
+
print(
|
785
|
+
f" š Running batch {batch_start // batch_size + 1} ({len(batch_instances)} episodes)..."
|
786
|
+
)
|
787
|
+
|
788
|
+
# Create tasks for this batch
|
789
|
+
batch_tasks = []
|
790
|
+
for i, task_instance in enumerate(batch_instances):
|
791
|
+
agent = CrafterReActAgent(llm, max_turns=config.max_turns, verbose=False)
|
792
|
+
batch_tasks.append(
|
793
|
+
run_single_episode(client, agent, task_instance, batch_start + i + 1)
|
794
|
+
)
|
795
|
+
|
796
|
+
# Run this batch in parallel with progress bar
|
797
|
+
batch_results = await tqdm_asyncio.gather(
|
798
|
+
*batch_tasks, desc=f"Batch {batch_start // batch_size + 1}", unit="episode"
|
799
|
+
)
|
800
|
+
all_results.extend(batch_results)
|
801
|
+
|
802
|
+
print(f" ā
Batch {batch_start // batch_size + 1} completed")
|
803
|
+
|
804
|
+
results = all_results
|
805
|
+
|
806
|
+
# Filter out error results
|
807
|
+
valid_results = [r for r in results if not r.get("error", False)]
|
808
|
+
|
809
|
+
if not valid_results:
|
810
|
+
return {
|
811
|
+
"eval_metrics": [],
|
812
|
+
"mean_eval_metric": 0.0,
|
813
|
+
"mean_rubric": {},
|
814
|
+
"num_episodes": 0,
|
815
|
+
}
|
816
|
+
|
817
|
+
# Extract eval metrics and rubrics
|
818
|
+
eval_metrics = [r["eval_metric"] for r in valid_results]
|
819
|
+
mean_eval_metric = sum(eval_metrics) / len(eval_metrics)
|
820
|
+
|
821
|
+
# --- Rollout length statistics ---
|
822
|
+
rollout_lengths = [r["rollout_length"] for r in valid_results]
|
823
|
+
sorted_lengths = sorted(rollout_lengths)
|
824
|
+
n_lengths = len(sorted_lengths)
|
825
|
+
# Median (Q2)
|
826
|
+
if n_lengths % 2 == 1:
|
827
|
+
q2_rollout = sorted_lengths[n_lengths // 2]
|
828
|
+
else:
|
829
|
+
q2_rollout = (sorted_lengths[n_lengths // 2 - 1] + sorted_lengths[n_lengths // 2]) / 2
|
830
|
+
# 90th percentile (P90)
|
831
|
+
p90_index = int(0.9 * (n_lengths - 1))
|
832
|
+
p90_rollout = sorted_lengths[p90_index]
|
833
|
+
max_rollout = sorted_lengths[-1]
|
834
|
+
|
835
|
+
# Calculate mean rubric values
|
836
|
+
all_rubric_keys = set()
|
837
|
+
for r in valid_results:
|
838
|
+
all_rubric_keys.update(r["rubric"].keys())
|
839
|
+
|
840
|
+
mean_rubric = {}
|
841
|
+
for key in all_rubric_keys:
|
842
|
+
values = [r["rubric"].get(key, 0.0) for r in valid_results]
|
843
|
+
mean_rubric[key] = sum(values) / len(values)
|
844
|
+
|
845
|
+
# Calculate shaped reward (training rubric)
|
846
|
+
# Count total achievements across all episodes
|
847
|
+
achievement_counts = {}
|
848
|
+
unique_achievements_per_trajectory = []
|
849
|
+
all_unique_achievements = set()
|
850
|
+
|
851
|
+
for result in valid_results:
|
852
|
+
achievements = result.get("achievements", {})
|
853
|
+
trajectory_achievements = set()
|
854
|
+
for achievement, unlocked in achievements.items():
|
855
|
+
if unlocked:
|
856
|
+
achievement_counts[achievement] = achievement_counts.get(achievement, 0) + 1
|
857
|
+
trajectory_achievements.add(achievement)
|
858
|
+
all_unique_achievements.add(achievement)
|
859
|
+
unique_achievements_per_trajectory.append(trajectory_achievements)
|
860
|
+
|
861
|
+
# Calculate shaped reward using the counts
|
862
|
+
shaped_reward_data = calculate_shaped_reward(achievement_counts)
|
863
|
+
|
864
|
+
# Calculate unique achievements by N trajectories
|
865
|
+
unique_achievements_by_n = {}
|
866
|
+
for n in range(1, len(valid_results) + 1):
|
867
|
+
unique_at_n = set()
|
868
|
+
for i in range(n):
|
869
|
+
unique_at_n.update(unique_achievements_per_trajectory[i])
|
870
|
+
unique_achievements_by_n[n] = len(unique_at_n)
|
871
|
+
|
872
|
+
# Create training rubric (normalized shaped reward components)
|
873
|
+
training_rubric = {}
|
874
|
+
total_episodes = len(valid_results)
|
875
|
+
if shaped_reward_data["breakdown"]:
|
876
|
+
for achievement, data in shaped_reward_data["breakdown"].items():
|
877
|
+
# Normalize by number of episodes for comparison
|
878
|
+
training_rubric[achievement] = data["contribution"] / total_episodes
|
879
|
+
|
880
|
+
return {
|
881
|
+
"eval_metrics": eval_metrics,
|
882
|
+
"mean_eval_metric": mean_eval_metric,
|
883
|
+
"mean_rubric": mean_rubric,
|
884
|
+
"achievement_counts": achievement_counts,
|
885
|
+
"shaped_reward_data": shaped_reward_data,
|
886
|
+
"training_rubric": training_rubric,
|
887
|
+
"unique_achievements_per_trajectory": unique_achievements_per_trajectory,
|
888
|
+
"all_unique_achievements": all_unique_achievements,
|
889
|
+
"unique_achievements_by_n": unique_achievements_by_n,
|
890
|
+
"num_episodes": len(valid_results),
|
891
|
+
"q2_rollout": q2_rollout,
|
892
|
+
"p90_rollout": p90_rollout,
|
893
|
+
"max_rollout": max_rollout,
|
894
|
+
}
|
895
|
+
|
896
|
+
|
897
|
+
async def main():
    """Run Crafter evaluation."""
    # Configure logging to reduce verbosity
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("google_genai").setLevel(logging.WARNING)
    logging.getLogger("google.generativeai").setLevel(logging.WARNING)
    logging.getLogger("google_genai.models").setLevel(logging.WARNING)
    logging.getLogger("google_genai.types").setLevel(logging.WARNING)

    print(f"🎮 Crafter ReAct Agent Evaluation")
    print(f"Model: {config.model_name}")
    print(f"Service: {config.service_base_url}")
    print(f"Instances: {config.num_instances}")
    print(f"Max Turns: {config.max_turns}")
    print(f"Difficulty: {config.difficulty}")
    print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("=" * 50)

    # Test service health
    async with AsyncClient(base_url=config.service_base_url, timeout=10.0) as client:
        try:
            health_resp = await client.get("/health")
            health_data = health_resp.json()

            if "CrafterClassic" not in health_data.get("supported_environments", []):
                print("āŒ CrafterClassic not available on service")
                return

            print("āœ… Service health check passed")

        except Exception as e:
            print(f"āŒ Service health check failed: {e}")
            return

    # Run evaluation
    try:
        results = await evaluate_crafter_batch()

        print("\n" + "=" * 80)
        print("šŸ† FINAL CRAFTER EVALUATION RESULTS")
        print("=" * 80)

        # Print eval metrics
        print(f"šŸ“Š EVAL METRICS:")
        print(f"   Episodes: {results['num_episodes']}")
        print(f"   Individual Scores: {[f'{x:.1f}' for x in results['eval_metrics']]}")
        print(f"   Mean Eval Metric: {results['mean_eval_metric']:.2f}")

        # Print standard rubric results
        print(f"\nšŸŽÆ STANDARD RUBRIC RESULTS:")
        if results["mean_rubric"]:
            for achievement, score in sorted(results["mean_rubric"].items()):
                print(f"   {achievement}: {score:.2f}")
        else:
            print("   No rubric data available")

        # Print shaped reward results
        print(f"\nšŸ‹ļø TRAINING EVAL SCORE (SHAPED REWARD):")
        shaped_data = results.get("shaped_reward_data", {})
        print(f"   Total Shaped Reward: {shaped_data.get('total_shaped_reward', 0.0):.3f}")

        # Print achievement counts and contributions
        achievement_counts = results.get("achievement_counts", {})
        if achievement_counts:
            print(f"\n   Achievement Counts Across All Episodes:")
            for achievement, count in sorted(achievement_counts.items()):
                k_value = ACHIEVEMENT_K_VALUES.get(achievement, 0.0)
                contribution = k_value * math.log(count + 1) if count > 0 else 0.0
                print(
                    f"      {achievement}: {count} times (K={k_value:.1f}, contribution={contribution:.3f})"
                )
        else:
            print("   No achievements unlocked")

        # Print training rubric (normalized contributions)
        print(f"\nšŸ‹ļø TRAINING RUBRIC (PER EPISODE):")
        if results.get("training_rubric"):
            for achievement, score in sorted(results["training_rubric"].items()):
                print(f"   {achievement}: {score:.3f}")
        else:
            print("   No training rubric data available")

        # Print unique achievements analysis
        print(f"\n🌟 UNIQUE ACHIEVEMENTS ANALYSIS:")
        all_unique = results.get("all_unique_achievements", set())
        print(f"   Total Unique Achievements Unlocked: {len(all_unique)}")
        if all_unique:
            print(f"   Unique Achievements: {', '.join(sorted(all_unique))}")

        # Print unique achievements by N trajectories
        unique_by_n = results.get("unique_achievements_by_n", {})
        if unique_by_n:
            print(f"\nšŸ“ˆ UNIQUE ACHIEVEMENTS BY N TRAJECTORIES:")
            for n in sorted(unique_by_n.keys()):
                print(f"   After {n} trajectories: {unique_by_n[n]} unique achievements")

        # Calculate average achievements per trajectory
        achievements_per_trajectory = [
            len(achievements)
            for achievements in results.get("unique_achievements_per_trajectory", [])
        ]
        if achievements_per_trajectory:
            avg_achievements = sum(achievements_per_trajectory) / len(achievements_per_trajectory)
            print(f"\nšŸ“Š TRAJECTORY ANALYSIS:")
            print(f"   Average Achievements per Trajectory: {avg_achievements:.2f}")
            print(f"   Achievements per Trajectory: {achievements_per_trajectory}")
            print(f"   Best Trajectory: {max(achievements_per_trajectory)} achievements")
            print(f"   Worst Trajectory: {min(achievements_per_trajectory)} achievements")

        # Overall assessment
        print(f"\nšŸ” ASSESSMENT:")
        if results["mean_eval_metric"] >= 3.0:
            print("šŸŽ‰ Excellent performance - achieving multiple objectives!")
        elif results["mean_eval_metric"] >= 1.0:
            print("āœ… Good performance - consistently achieving objectives!")
        elif results["mean_eval_metric"] >= 0.5:
            print("āš ļø Moderate performance - some achievements unlocked")
        else:
            print("šŸ“š Learning phase - focus on basic survival and resource gathering")

        # Output markdown table row for README collation
        print(f"\nšŸ“‹ MARKDOWN TABLE ROW:")
        print(
            "| Model | Episodes | Mean Score | Avg Achievements | Unique Achievements | Shaped Reward | Mean K-Score | Q2 Len | P90 Len | Max Len |"
        )
        print(
            "|------------------|----------|------------|------------------|---------------------|---------------|--------------|--------|---------|---------|"
        )
        achievements_per_trajectory = [
            len(achievements)
            for achievements in results.get("unique_achievements_per_trajectory", [])
        ]
        avg_achievements = (
            sum(achievements_per_trajectory) / len(achievements_per_trajectory)
            if achievements_per_trajectory
            else 0.0
        )
        total_unique = len(results.get("all_unique_achievements", set()))
        shaped_reward = results.get("shaped_reward_data", {}).get("total_shaped_reward", 0.0)
        mean_k_score = (
            shaped_reward / results["num_episodes"] if results["num_episodes"] > 0 else 0.0
        )
        q2_rollout = results.get("q2_rollout", 0)
        p90_rollout = results.get("p90_rollout", 0)
        max_rollout = results.get("max_rollout", 0)

        print(
            f"| {config.model_name:<16} | {results['num_episodes']:>8} | {results['mean_eval_metric']:>10.2f} | {avg_achievements:>16.2f} | {total_unique:>19} | {shaped_reward:>13.3f} | {mean_k_score:>12.3f} | {q2_rollout:>6} | {p90_rollout:>7} | {max_rollout:>7} |"
        )

    except Exception as e:
        print(f"āŒ Evaluation failed: {e}")
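
Note: the health check at the top of main() only assumes that the environment service exposes a /health endpoint whose JSON body contains a supported_environments list. A minimal payload that would pass that check might look like the sketch below; the status field and the other environment names are illustrative assumptions, not guaranteed by the service.

# Illustrative /health payload that satisfies the check in main();
# the real service may return additional fields.
health_data = {
    "status": "healthy",  # assumed field, not consulted by the check
    "supported_environments": ["CrafterClassic", "Enron", "MiniGrid"],
}
assert "CrafterClassic" in health_data.get("supported_environments", [])
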
if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Run Crafter ReAct Agent Evaluation")
    parser.add_argument("--config", "-c", type=str, help="Path to TOML configuration file")
    parser.add_argument("--model", "-m", type=str, help="Model name (overrides config)")
    parser.add_argument("--episodes", "-e", type=int, help="Number of episodes (overrides config)")
    parser.add_argument(
        "--max-turns", "-t", type=int, help="Maximum turns per episode (overrides config)"
    )
    parser.add_argument("--difficulty", "-d", type=str, help="Difficulty level (overrides config)")

    # Custom OpenAI endpoint support
    parser.add_argument(
        "--openai-base-url",
        type=str,
        help="Custom OpenAI-compatible base URL (e.g., https://lora-service.modal.run)",
    )
    parser.add_argument(
        "--openai-api-key",
        type=str,
        default="dummy",
        help="API key for custom endpoint (default: 'dummy')",
    )

    args = parser.parse_args()

    # Load configuration
    if args.config:
        config = CrafterConfig(args.config)
    else:
        # Try to load default config
        default_config_path = (
            Path(__file__).parent.parent.parent.parent / "evals" / "configs" / "crafter.toml"
        )
        if default_config_path.exists():
            config = CrafterConfig(str(default_config_path))
        else:
            config = CrafterConfig()

    # Override with command line arguments
    if args.model:
        config.model_name = args.model
    if args.episodes:
        config.num_instances = args.episodes
    if args.max_turns:
        config.max_turns = args.max_turns
    if args.difficulty:
        config.difficulty = args.difficulty

    # Configure custom OpenAI endpoint if provided
    if args.openai_base_url:
        config.set_custom_endpoint(args.openai_base_url, args.openai_api_key)

    # Fail fast if model_name still missing
    if not config.model_name:
        raise ValueError(
            "CrafterConfig: 'model_name' must be specified in the TOML config or via --model CLI argument; no fallback default."
        )

    asyncio.run(main())
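
Usage note: the __main__ block resolves settings from a TOML config first (the --config path, or evals/configs/crafter.toml relative to the repo layout) and then applies CLI overrides, so a run only needs the flags it wants to change. An illustrative invocation (model name and episode count are examples, not defaults shipped with the package) would be:

python crafter_react_agent.py --config evals/configs/crafter.toml --model gpt-4o-mini --episodes 5 --max-turns 30

Pointing --openai-base-url at an OpenAI-compatible endpoint routes inference there via config.set_custom_endpoint; --openai-api-key defaults to "dummy" for services that do not check credentials.
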