synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.1.9.dist-info/METADATA +0 -37
- synth_ai-0.1.9.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1112 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Test script to run ReAct agents against NetHack environment on synth service (port 8901)
|
4
|
+
Tests on multiple easy NetHack instances with enhanced debugging
|
5
|
+
"""
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import json
|
9
|
+
import uuid
|
10
|
+
from datetime import datetime
|
11
|
+
from typing import Dict, Any, Optional, List
|
12
|
+
from pydantic import BaseModel, Field
|
13
|
+
from httpx import AsyncClient
|
14
|
+
import sys
|
15
|
+
import os
|
16
|
+
from tqdm import tqdm
|
17
|
+
|
18
|
+
# Add the src directory to the path
|
19
|
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "src"))
|
20
|
+
|
21
|
+
from synth_ai.zyk import LM
|
22
|
+
from synth_ai.zyk.lms.tools.base import BaseTool
|
23
|
+
|
24
|
+
|
25
|
+
# --- Configuration Class ---
class NetHackConfig:
    """Configuration for NetHack evaluation (mirrors CrafterConfig).

    Starts from hard-coded defaults, then optionally overlays values read
    from a TOML file. Any failure while reading/parsing the file is
    reported as a warning and the defaults are kept.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Build a config, optionally overriding defaults from *config_path*."""
        # Hard-coded defaults (used whenever a TOML key is absent).
        self.model_name = "gpt-4.1-mini"
        self.num_instances = 2
        self.max_turns = 40
        self.difficulty = "beginner"
        self.service_base_url = "http://localhost:8901"
        self.service_timeout = 30.0
        self.seed = 42
        self.save_traces = True
        self.save_detailed_results = True

        # Overlay values from the TOML file, if one was supplied and exists.
        if config_path and os.path.exists(config_path):
            try:
                import toml

                data = toml.load(config_path)

                # [eval] section — evaluation parameters.
                section = data.get("eval", {})
                self.model_name = section.get("model_name", self.model_name)
                self.num_instances = section.get("episodes", self.num_instances)
                self.max_turns = section.get("max_steps", self.max_turns)
                self.difficulty = section.get("difficulty", self.difficulty)
                self.seed = section.get("seed", self.seed)

                # [service] section — environment-service endpoint.
                section = data.get("service", {})
                self.service_base_url = section.get("base_url", self.service_base_url)
                self.service_timeout = section.get("timeout", self.service_timeout)

                # [output] section — what artifacts to persist.
                section = data.get("output", {})
                self.save_traces = section.get("save_traces", self.save_traces)
                self.save_detailed_results = section.get(
                    "save_detailed_results", self.save_detailed_results
                )
            except Exception as e:
                # Best-effort: a broken/missing-toml config keeps the defaults.
                print(f"[WARNING] Failed to load config from {config_path}: {e}")
|
66
|
+
|
67
|
+
|
68
|
+
# Instantiate default config (may be overridden by CLI later)
# Module-level singleton read by the rest of the script; CLI parsing is
# deferred to the end of the file, so it may update/replace these values.
config = NetHackConfig()
|
70
|
+
|
71
|
+
|
72
|
+
# Overwrite the original global constants to use config values (so rest of script works unchanged)
def _apply_config_to_globals(cfg: NetHackConfig):
    """Mirror the config's values onto the module-level constant names.

    The rest of the script reads MODEL_NAME / NUM_INSTANCES / MAX_TURNS /
    DIFFICULTY / SERVICE_BASE_URL directly, so copying the config into the
    module namespace lets that code run unchanged.
    """
    module_globals = globals()
    for constant_name, value in (
        ("MODEL_NAME", cfg.model_name),
        ("NUM_INSTANCES", cfg.num_instances),
        ("MAX_TURNS", cfg.max_turns),
        ("DIFFICULTY", cfg.difficulty),
        ("SERVICE_BASE_URL", cfg.service_base_url),
    ):
        module_globals[constant_name] = value
|
79
|
+
|
80
|
+
|
81
|
+
# Push the default config's values into the module constants at import time.
_apply_config_to_globals(config)
|
82
|
+
|
83
|
+
# --- CLI Override (similar to Crafter script) ---
# CLI parsing moved to end of file after main() is defined


# --- Service Configuration ---
# NOTE(review): these assignments execute AFTER _apply_config_to_globals(config)
# above and re-assign the same global names to hard-coded defaults. Today the
# values are identical to NetHackConfig's defaults so nothing observable
# changes, but a value loaded from TOML before this point would be clobbered
# here — confirm the intended ordering.
SERVICE_BASE_URL = "http://localhost:8901"
MODEL_NAME = "gpt-4.1-mini"
NUM_INSTANCES = 2
MAX_TURNS = 40
DIFFICULTY = "beginner"  # beginner, intermediate, advanced, expert
|
93
|
+
|
94
|
+
|
95
|
+
# --- Tool Definitions ---
class NetHackActionArgs(BaseModel):
    """Arguments for nethack actions.

    Schema the LLM must fill when calling the "interact" tool.
    """

    # 1-3 action names, executed in order by the environment service.
    actions: List[str] = Field(
        description="List of 1-3 action names to execute in sequence (e.g., ['north', 'search', 'inventory'])"
    )
    # Short free-text rationale; useful for debugging agent decisions.
    reasoning: str = Field(description="Brief explanation of why these actions were chosen")
|
103
|
+
|
104
|
+
|
105
|
+
class TerminateArgs(BaseModel):
    """Arguments for termination.

    Schema the LLM must fill when calling the "terminate" tool.
    """

    # Human-readable justification recorded when the agent ends the episode.
    reason: str = Field(description="Reason for termination")
|
109
|
+
|
110
|
+
|
111
|
+
class NetHackActionTool(BaseTool):
    """Tool for performing actions in the NetHack environment."""

    # Tool name exposed to the LLM's tool-calling API.
    name: str = "interact"
    # Pydantic schema for the tool's arguments.
    arguments: type[BaseModel] = NetHackActionArgs
    description: str = "Perform 1-3 actions in sequence in the NetHack environment."
|
117
|
+
|
118
|
+
|
119
|
+
class TerminateTool(BaseTool):
    """Tool to terminate the episode."""

    # Tool name exposed to the LLM's tool-calling API.
    name: str = "terminate"
    # Pydantic schema for the tool's arguments.
    arguments: type[BaseModel] = TerminateArgs
    description: str = "End the episode when finished or no progress can be made."
|
125
|
+
|
126
|
+
|
127
|
+
# --- Base ReAct Agent ---
class BaseReActAgent:
    """Base ReAct agent for environment interaction.

    Wraps an LM with a fixed tool set (interact / terminate) and turns a
    formatted text observation into a single tool-call decision per turn.
    """

    def __init__(self, llm: LM, max_turns: int = 30, verbose: bool = False):
        self.llm = llm
        self.max_turns = max_turns
        self.verbose = verbose
        self.history = []  # filled in by callers/subclasses; unused here
        self.system_name = "base-react-agent"

        # Define tools in OpenAI format
        self.tools = [
            NetHackActionTool(),
            TerminateTool(),
        ]

    @staticmethod
    def _default_decision(reasoning: str) -> Dict[str, Any]:
        """Safe fallback decision used whenever the LLM response is unusable."""
        return {
            "name": "interact",
            "parameters": {
                "actions": ["inventory"],
                "reasoning": reasoning,
            },
        }

    async def decide(self, obs: str, system_message: str, turn: int) -> Dict[str, Any]:
        """Get agent decision based on observation.

        Returns a dict {"name": <tool name>, "parameters": <tool args>}.
        Falls back to a harmless default action when the LLM returns no tool
        call or malformed tool arguments, instead of raising.
        """
        # Create conversation context
        context = f"Turn {turn + 1}/{self.max_turns}\n\n{obs}"

        # Generate response using LLM
        response_obj = await self.llm.respond_async(
            system_message=system_message, user_message=context, tools=self.tools
        )

        tool_calls = response_obj.tool_calls

        # Handle case where tool_calls is None or empty (graceful fallback)
        if not tool_calls:
            if self.verbose:
                print("[WARNING] No tool calls returned by LLM, using default action")
            return self._default_decision("Default action - no tool call received")

        tool_call_data = tool_calls[0]

        # Handle both dict and object formats
        if isinstance(tool_call_data, dict):
            tool_name = tool_call_data["function"]["name"]
            tool_args_str = tool_call_data["function"]["arguments"]
        else:
            tool_name = tool_call_data.function.name
            tool_args_str = tool_call_data.function.arguments

        # BUG FIX: json.loads was previously unguarded, so malformed tool
        # arguments from the LLM raised and aborted the whole episode even
        # though the missing-tool-call case above already degrades gracefully.
        # Treat malformed arguments the same way.
        try:
            tool_arguments = json.loads(tool_args_str)
        except (json.JSONDecodeError, TypeError) as e:
            if self.verbose:
                print(f"[WARNING] Could not parse tool arguments ({e}), using default action")
            return self._default_decision("Default action - malformed tool arguments")

        return {"name": tool_name, "parameters": tool_arguments}
|
181
|
+
|
182
|
+
|
183
|
+
# --- NetHack ReAct Agent ---
class NetHackReActAgent(BaseReActAgent):
    """ReAct agent for NetHack environment.

    Specializes BaseReActAgent with a NetHack-specific system prompt and an
    observation formatter that flattens the service's observation dict into
    plain text for the LLM.
    """

    def __init__(self, llm: LM, max_turns: int = 30, verbose: bool = False):
        super().__init__(llm, max_turns, verbose)
        self.system_name = "nethack-react-agent"  # overrides the base identifier

    def get_system_message(self) -> str:
        # Static prompt: action vocabulary, map legend, and strategy hints.
        # This text is part of the agent's runtime behavior — keep verbatim.
        return """You are an expert NetHack player. Your goal is to explore the dungeon, survive, and make progress.

MOVEMENT ACTIONS:
- north, south, east, west: Move in cardinal directions
- northeast, northwest, southeast, southwest: Move diagonally
- go_up, go_down: Use stairs (must be on < or > symbol)

EXPLORATION ACTIONS:
- search: Look for secret doors or traps
- open: Open doors
- close: Close doors
- look: Examine surroundings (FREE ACTION)

INVENTORY ACTIONS:
- inventory: Check your items (FREE ACTION)
- pickup: Pick up items
- drop: Drop items
- wear: Put on armor
- wield: Equip weapon
- eat: Consume food
- drink: Drink potion
- read: Read scroll

INTERACTION:
- wait: Rest for one turn
- chat: Talk to NPCs
- pay: Pay shopkeeper
- kick: Kick something

MAP SYMBOLS:
- @ = you (the player)
- . = floor
- # = wall/corridor
- + = closed door
- - = open door
- < = stairs up
- > = stairs down
- $ = gold
- % = food
- ! = potion
- ? = scroll
- / = wand
- ) = weapon
- [ = armor
- d,f = pets (dog/cat)
- Letters = monsters

STRATEGY:
1. Explore systematically to map the dungeon
2. Collect useful items and gold
3. Manage hunger by eating food
4. Fight weak monsters for experience
5. Use 'look' and 'inventory' frequently (they're free!)
6. Be cautious around unknown monsters

Remember: NetHack is complex but rewarding. Take your time and observe carefully."""

    def format_observation(self, obs: Dict[str, Any]) -> str:
        """Format observation for NetHack.

        Builds a newline-joined text summary from whichever of the known keys
        are present in *obs* (ascii_map, message, character_stats,
        inventory_summary, hunger_status, terminated, reward). Missing keys
        are simply skipped; an empty result yields a placeholder string.
        """
        parts = []

        if "ascii_map" in obs:
            parts.append("ASCII Map:")
            parts.append(obs["ascii_map"])

        # Only include the message when it is non-empty/truthy.
        if "message" in obs and obs["message"]:
            parts.append(f"Message: {obs['message']}")

        if "character_stats" in obs:
            stats = obs["character_stats"]
            stat_items = []
            # Whitelist of stats worth surfacing to the LLM each turn.
            for key, value in stats.items():
                if key in ["HP", "level", "gold", "score", "turn"]:
                    stat_items.append(f"{key}: {value}")
            if stat_items:
                parts.append(f"Stats: {', '.join(stat_items)}")

        if "inventory_summary" in obs:
            parts.append(f"Inventory: {obs['inventory_summary']}")

        if "hunger_status" in obs and obs["hunger_status"]:
            parts.append(f"Hunger: {obs['hunger_status']}")

        if "terminated" in obs:
            parts.append(f"Terminated: {obs['terminated']}")

        if "reward" in obs:
            parts.append(f"Reward: {obs['reward']}")

        return "\n".join(parts) if parts else "No formatted observation available"
|
282
|
+
|
283
|
+
|
284
|
+
# --- Episode Runner ---
async def run_single_episode(
    client: AsyncClient,
    agent: NetHackReActAgent,
    task_instance,
    instance_num: int,
    progress_bar=None,
) -> Dict[str, Any]:
    """Run a single NetHack episode and return episode metrics.

    Drives one full episode against the environment service: creates an
    environment from *task_instance*, loops up to ``agent.max_turns`` asking
    the agent for an action sequence and stepping the environment, accumulates
    shaped rewards and progress counters, then terminates the environment and
    returns a metrics dict (eval metric, rubric, shaped rewards, counters).

    Args:
        client: HTTP client pointed at the environment service.
        agent: ReAct agent providing ``decide`` / ``format_observation``.
        task_instance: task to run; must expose ``serialize()`` and
            ``metadata.target_depth``.
        instance_num: 1-based index used only for log messages.
        progress_bar: optional tqdm bar advanced once per environment step.

    Returns:
        Metrics dict; on any failure returns
        ``{"eval_metric": 0.0, "rubric": {}, "error": True}``.
    """
    try:
        # Create environment using the task instance
        create_resp = await client.post(
            f"/env/NetHack/initialize", json={"task_instance": await task_instance.serialize()}
        )

        if create_resp.status_code != 200:
            print(
                f" Instance {instance_num}: Failed to create environment - {create_resp.status_code}: {create_resp.text}"
            )
            return {"eval_metric": 0.0, "rubric": {}, "error": True}

        env_id = create_resp.json()["env_id"]

        # Get initial observation
        obs = create_resp.json()["observation"]
        formatted_obs = agent.format_observation(obs)

        # DEBUG: Print initial state
        # print(f"\n Instance {instance_num}: Starting NetHack adventure")
        # print(f" Character: {task_instance.metadata.character_role}")
        # print(f" Goal: Reach depth {task_instance.metadata.target_depth}")

        # Track progress
        initial_depth = 1  # NetHack always starts on dungeon level 1
        max_depth_reached = initial_depth
        max_reward = 0.0
        final_stats = {}
        balrog_score = 0.0
        balrog_total_reward = 0.0
        achievements_unlocked = []

        # Track additional progress metrics
        monsters_killed = 0
        items_picked_up = 0
        scrolls_read = 0
        potions_drunk = 0
        rooms_explored = 0
        secret_doors_found = 0
        stairs_found = 0
        traps_encountered = 0
        spells_cast = 0
        prayers_attempted = 0
        max_score = 0

        # Track shaped rewards (requires previous observation)
        prev_obs = None
        shaped_rewards = {
            # Survival & Progress
            "depth_delta_total": 0.0,
            "stairs_seen_total": 0,
            "turn_alive_total": 0.0,
            "hp_gain_total": 0.0,
            "hunger_ok_total": 0,
            # Exploration
            "new_tiles_total": 0,
            "rooms_explored_delta_total": 0,
            "secret_doors_delta_total": 0,
            "traps_identified_delta_total": 0,
            # Combat
            "monsters_killed_delta_total": 0,
            "dmg_dealt_total": 0.0,
            "dmg_taken_total": 0.0,
            # Resources
            "gold_delta_total": 0,
            "items_picked_delta_total": 0,
            "scrolls_read_delta_total": 0,
            "potions_quaffed_delta_total": 0,
            "spells_cast_delta_total": 0,
            # Skill/Utility
            "first_prayer_achieved": False,
            "first_spell_achieved": False,
            "identify_item_total": 0,
            # Achievements
            "achievement_unlocked_total": 0,
            # Intermediate reward accumulation
            "total_intermediate_reward": 0.0,
        }

        # Run episode
        for turn in range(agent.max_turns):
            # Get agent decision
            action = await agent.decide(formatted_obs, agent.get_system_message(), turn)

            # Check for termination
            if action["name"] == "terminate":
                print(
                    f" Agent terminated: {action['parameters'].get('reason', 'no reason given')}"
                )
                break

            # Execute actions in environment
            action_sequence = action["parameters"]["actions"]

            step_resp = await client.post(
                f"/env/NetHack/step",
                json={
                    "env_id": env_id,
                    "request_id": str(uuid.uuid4()),
                    "action": {
                        "tool_calls": [{"tool": "interact", "args": {"actions": action_sequence}}]
                    },
                },
            )

            if step_resp.status_code != 200:
                print(f" ā Step failed: {step_resp.status_code}: {step_resp.text}")
                break

            obs = step_resp.json()["observation"]
            formatted_obs = agent.format_observation(obs)

            # Calculate shaped rewards if we have a previous observation
            # (first iteration is skipped: deltas need a baseline).
            if prev_obs is not None:
                # --- Survival & Progress ---
                current_depth = obs.get("character_stats", {}).get("dungeon_level", 1)
                prev_depth = prev_obs.get("character_stats", {}).get("dungeon_level", 1)
                depth_delta = current_depth - prev_depth
                shaped_rewards["depth_delta_total"] += depth_delta

                stairs_seen = int(obs.get("stairs_found", 0) > prev_obs.get("stairs_found", 0))
                shaped_rewards["stairs_seen_total"] += stairs_seen

                shaped_rewards["turn_alive_total"] += 0.01  # tiny tick reward every step survived

                # HP calculations
                current_hp = obs.get("character_stats", {}).get("hp", 1)
                current_max_hp = obs.get("character_stats", {}).get("max_hp", 1)
                prev_hp = prev_obs.get("character_stats", {}).get("hp", 1)
                prev_max_hp = prev_obs.get("character_stats", {}).get("max_hp", 1)

                if current_max_hp > 0 and prev_max_hp > 0:
                    hp_pct = current_hp / current_max_hp
                    prev_hp_pct = prev_hp / prev_max_hp
                    hp_gain = hp_pct - prev_hp_pct
                    shaped_rewards["hp_gain_total"] += hp_gain

                # Empty string counts as "ok": a missing hunger field is not penalized.
                hunger_ok = int(obs.get("hunger_status", "") in ("Not hungry", "Satiated", ""))
                shaped_rewards["hunger_ok_total"] += hunger_ok

                # --- Exploration ---
                new_tiles = obs.get("exploration_stats", {}).get("new_tiles", 0)
                shaped_rewards["new_tiles_total"] += new_tiles

                rooms_explored_delta = obs.get("rooms_explored", 0) - prev_obs.get(
                    "rooms_explored", 0
                )
                shaped_rewards["rooms_explored_delta_total"] += rooms_explored_delta

                secret_doors_delta = obs.get("secret_doors_found", 0) - prev_obs.get(
                    "secret_doors_found", 0
                )
                shaped_rewards["secret_doors_delta_total"] += secret_doors_delta

                traps_identified_delta = obs.get("traps_encountered", 0) - prev_obs.get(
                    "traps_encountered", 0
                )
                shaped_rewards["traps_identified_delta_total"] += traps_identified_delta

                # --- Combat ---
                monsters_killed_delta = obs.get("achievement_stats", {}).get(
                    "monsters_killed", 0
                ) - prev_obs.get("achievement_stats", {}).get("monsters_killed", 0)
                shaped_rewards["monsters_killed_delta_total"] += monsters_killed_delta

                dmg_dealt = obs.get("combat", {}).get("damage_dealt", 0)
                shaped_rewards["dmg_dealt_total"] += dmg_dealt

                dmg_taken = obs.get("combat", {}).get("damage_taken", 0)
                shaped_rewards["dmg_taken_total"] += dmg_taken

                # --- Resources ---
                gold_delta = obs.get("character_stats", {}).get("gold", 0) - prev_obs.get(
                    "character_stats", {}
                ).get("gold", 0)
                shaped_rewards["gold_delta_total"] += gold_delta

                items_picked_delta = obs.get("items_collected", 0) - prev_obs.get(
                    "items_collected", 0
                )
                shaped_rewards["items_picked_delta_total"] += items_picked_delta

                scrolls_read_delta = obs.get("scrolls_read", 0) - prev_obs.get("scrolls_read", 0)
                shaped_rewards["scrolls_read_delta_total"] += scrolls_read_delta

                potions_quaffed_delta = obs.get("potions_drunk", 0) - prev_obs.get(
                    "potions_drunk", 0
                )
                shaped_rewards["potions_quaffed_delta_total"] += potions_quaffed_delta

                spells_cast_delta = obs.get("spells_cast", 0) - prev_obs.get("spells_cast", 0)
                shaped_rewards["spells_cast_delta_total"] += spells_cast_delta

                # --- Skill/Utility ---
                # One-shot flags: flipped the first time the counter goes 0 -> >0.
                if (
                    obs.get("prayers_attempted", 0) > 0
                    and prev_obs.get("prayers_attempted", 0) == 0
                ):
                    shaped_rewards["first_prayer_achieved"] = True

                if spells_cast_delta > 0 and prev_obs.get("spells_cast", 0) == 0:
                    shaped_rewards["first_spell_achieved"] = True

                message = obs.get("message", "")
                if isinstance(message, bytes):
                    message = message.decode("ascii", errors="ignore").strip("\x00")
                if "You identify" in message:
                    shaped_rewards["identify_item_total"] += 1

                # --- Achievements ---
                current_achievements = obs.get("achievements_unlocked", {})
                prev_achievements = prev_obs.get("achievements_unlocked", {})
                achievement_unlocked = sum(
                    int(v and not prev_achievements.get(k, False))
                    for k, v in current_achievements.items()
                )
                shaped_rewards["achievement_unlocked_total"] += achievement_unlocked

                # --- Calculate intermediate reward ---
                # Hand-tuned linear combination of the per-step deltas above.
                intermediate_reward = (
                    1.0 * depth_delta
                    + 0.2 * new_tiles
                    + 2.0 * monsters_killed_delta
                    - 0.5 * dmg_taken / 10
                    + 0.1 * gold_delta
                    + 5.0 * achievement_unlocked
                )
                shaped_rewards["total_intermediate_reward"] += intermediate_reward

            # Store current observation as previous for next iteration
            # NOTE: shallow copy — nested dicts are shared, but obs is replaced
            # wholesale by the next json() parse, so this is safe here.
            prev_obs = obs.copy() if obs else None

            # Track progress
            if "character_stats" in obs:
                final_stats = obs["character_stats"]
                if "dungeon_level" in final_stats:
                    current_depth = final_stats["dungeon_level"]
                    max_depth_reached = max(max_depth_reached, current_depth)

            reward = obs.get("reward", 0.0)
            max_reward = max(max_reward, reward)

            # Track achievements and Balrog rewards (like in main agent)
            if "achievements_unlocked" in obs:
                for ach, unlocked in obs["achievements_unlocked"].items():
                    if unlocked and ach not in achievements_unlocked:
                        achievements_unlocked.append(ach)

            if "balrog_total_reward" in obs:
                balrog_total_reward = obs["balrog_total_reward"]

            if "achievement_stats" in obs and "balrog_score" in obs["achievement_stats"]:
                balrog_score = obs["achievement_stats"]["balrog_score"]

            # Track additional progress metrics from achievement stats
            # (these are cumulative server-side, so overwrite rather than add).
            if "achievement_stats" in obs:
                ach_stats = obs["achievement_stats"]
                monsters_killed = ach_stats.get("monsters_killed", 0)
                items_picked_up = ach_stats.get("items_collected", 0)
                rooms_explored = ach_stats.get("rooms_explored", 0)
                secret_doors_found = ach_stats.get("secret_doors_found", 0)
                stairs_found = ach_stats.get("stairs_found", 0)

            # Track score progression
            current_score = obs.get("score", 0)
            max_score = max(max_score, current_score)

            # Parse message for additional events
            message = obs.get("message", "")
            if isinstance(message, bytes):
                message = message.decode("ascii", errors="ignore").strip("\x00")

            # Look for specific events in messages
            # NOTE(review): substring heuristics — "spell" in any message counts
            # as a cast; these counters are approximate by design.
            if "You read" in message:
                scrolls_read += 1
            elif "You drink" in message:
                potions_drunk += 1
            elif "You cast" in message or "spell" in message.lower():
                spells_cast += 1
            elif "You pray" in message:
                prayers_attempted += 1
            elif "trap" in message.lower():
                traps_encountered += 1

            # Check if episode ended
            terminated = obs.get("terminated", False)

            if terminated:
                print(
                    f" š Instance {instance_num}: Episode ended at depth {max_depth_reached}, reward: {max_reward:.3f}"
                )
                break

            # Update progress bar
            if progress_bar is not None:
                progress_bar.update(1)

        # Cleanup
        await client.post(f"/env/NetHack/terminate", json={"env_id": env_id})

        # Ensure progress bar completes
        if progress_bar is not None:
            progress_bar.n = progress_bar.total
            progress_bar.close()

        # Calculate eval metric and rubric
        target_depth = task_instance.metadata.target_depth

        # Balrog score: Use proper score from observation (like in main agent)
        # This is the standard NetHack evaluation metric

        # Eval metric is the normalized Balrog score (0-1)
        eval_metric = balrog_score / 100.0

        # Create rubric with specific achievements
        rubric = {
            # Core progression metrics
            "reached_target_depth": 1.0 if max_depth_reached >= target_depth else 0.0,
            "depth_progress": min(1.0, max_depth_reached / target_depth),
            "gained_experience": 1.0 if final_stats.get("experience", 0) > 0 else 0.0,
            "collected_gold": 1.0 if final_stats.get("gold", 0) > 100 else 0.0,
            "gained_levels": 1.0 if final_stats.get("level", 1) > 1 else 0.0,
            "survived_turns": min(1.0, len(agent.history) / 20.0),  # Normalize to 20 turns
            "positive_reward": 1.0 if max_reward > 0 else 0.0,
            "achievement_fraction": len(achievements_unlocked)
            / 100.0,  # Core Balrog metric (approximated)
            # Combat and interaction metrics
            "monsters_defeated": min(1.0, monsters_killed / 5.0),  # Normalize to 5 kills
            "items_collected": min(1.0, items_picked_up / 10.0),  # Normalize to 10 items
            "scrolls_used": min(1.0, scrolls_read / 3.0),  # Normalize to 3 scrolls
            "potions_used": min(1.0, potions_drunk / 2.0),  # Normalize to 2 potions
            "spells_cast": min(1.0, spells_cast / 2.0),  # Normalize to 2 spells
            # Exploration metrics
            "rooms_explored": min(1.0, rooms_explored / 5.0),  # Normalize to 5 rooms
            "secret_doors_found": 1.0 if secret_doors_found > 0 else 0.0,
            "stairs_found": 1.0 if stairs_found > 0 else 0.0,
            "traps_encountered": 1.0 if traps_encountered > 0 else 0.0,
            # Advanced metrics
            "prayers_attempted": 1.0 if prayers_attempted > 0 else 0.0,
            "score_progress": min(1.0, max_score / 100.0),  # Normalize to 100 points
            "active_exploration": 1.0
            if (rooms_explored + secret_doors_found + stairs_found) > 0
            else 0.0,
            "item_interaction": 1.0 if (scrolls_read + potions_drunk + spells_cast) > 0 else 0.0,
            # --- Shaped Rewards ---
            # Survival & Progress
            "depth_progress_reward": max(0.0, shaped_rewards["depth_delta_total"]),
            "stairs_discovery_reward": min(1.0, shaped_rewards["stairs_seen_total"] / 5.0),
            "survival_reward": min(
                1.0, shaped_rewards["turn_alive_total"] / 1.0
            ),  # Normalize to 1.0 for 100 turns
            "hp_management_reward": max(0.0, shaped_rewards["hp_gain_total"]),
            "hunger_management_reward": min(
                1.0, shaped_rewards["hunger_ok_total"] / (len(agent.history) or 1)
            ),
            # Exploration
            "new_tiles_reward": min(
                1.0, shaped_rewards["new_tiles_total"] / 100.0
            ),  # Normalize to 100 tiles
            "room_discovery_reward": min(1.0, shaped_rewards["rooms_explored_delta_total"] / 5.0),
            "secret_discovery_reward": min(1.0, shaped_rewards["secret_doors_delta_total"] / 3.0),
            "trap_discovery_reward": min(1.0, shaped_rewards["traps_identified_delta_total"] / 3.0),
            # Combat
            "combat_success_reward": min(1.0, shaped_rewards["monsters_killed_delta_total"] / 5.0),
            "damage_dealt_reward": min(1.0, shaped_rewards["dmg_dealt_total"] / 50.0),
            "damage_avoided_reward": max(0.0, 1.0 - shaped_rewards["dmg_taken_total"] / 50.0),
            # Resources
            "wealth_accumulation_reward": min(1.0, shaped_rewards["gold_delta_total"] / 100.0),
            "item_collection_reward": min(1.0, shaped_rewards["items_picked_delta_total"] / 10.0),
            "scroll_usage_reward": min(1.0, shaped_rewards["scrolls_read_delta_total"] / 3.0),
            "potion_usage_reward": min(1.0, shaped_rewards["potions_quaffed_delta_total"] / 3.0),
            "spell_usage_reward": min(1.0, shaped_rewards["spells_cast_delta_total"] / 3.0),
            # Skill/Utility
            "first_prayer_reward": 1.0 if shaped_rewards["first_prayer_achieved"] else 0.0,
            "first_spell_reward": 1.0 if shaped_rewards["first_spell_achieved"] else 0.0,
            "identification_reward": min(1.0, shaped_rewards["identify_item_total"] / 3.0),
            # Achievements
            "achievement_unlock_reward": min(
                1.0, shaped_rewards["achievement_unlocked_total"] / 10.0
            ),
            # Overall shaped reward
            "total_intermediate_reward": shaped_rewards["total_intermediate_reward"],
            "normalized_intermediate_reward": min(
                1.0, max(0.0, shaped_rewards["total_intermediate_reward"] / 20.0)
            ),
        }

        # Remove or mark irrelevant rubric keys
        irrelevant_rubric = {}
        for k in list(rubric.keys()):
            if k in IRRELEVANT_RUBRIC_KEYS:
                irrelevant_rubric[k] = rubric.pop(k)

        # Success determination
        success = max_depth_reached >= target_depth or max_reward > 10.0 or balrog_score > 5.0

        if success:
            print(
                f" ā Instance {instance_num}: SUCCESS! Depth {max_depth_reached}, Balrog score: {balrog_score:.0f}"
            )
        else:
            print(
                f" ā Instance {instance_num}: Partial progress - depth {max_depth_reached}/{target_depth}, Balrog score: {balrog_score:.0f}"
            )

        return {
            "eval_metric": eval_metric,
            "rubric": rubric,
            "max_depth_reached": max_depth_reached,
            "target_depth": target_depth,
            "max_reward": max_reward,
            "balrog_score": balrog_score,
            "balrog_total_reward": balrog_total_reward,
            "achievements_unlocked": achievements_unlocked,
            "final_stats": final_stats,
            "success": success,
            "error": False,
            # Additional progress metrics
            "monsters_killed": monsters_killed,
            "items_picked_up": items_picked_up,
            "scrolls_read": scrolls_read,
            "potions_drunk": potions_drunk,
            "rooms_explored": rooms_explored,
            "secret_doors_found": secret_doors_found,
            "stairs_found": stairs_found,
            "traps_encountered": traps_encountered,
            "spells_cast": spells_cast,
            "prayers_attempted": prayers_attempted,
            "max_score": max_score,
            # Shaped rewards
            "shaped_rewards": shaped_rewards,
            "irrelevant_rubric": irrelevant_rubric,
        }

    except Exception as e:
        print(f" Instance {instance_num}: Error - {e}")
        import traceback

        traceback.print_exc()
        return {"eval_metric": 0.0, "rubric": {}, "error": True}
# --- Batch Evaluation ---
async def evaluate_nethack_batch() -> Dict[str, Any]:
    """Evaluate NetHack agent on multiple easy instances.

    Loads the NetHack taskset, filters instances by the module-level
    ``DIFFICULTY`` (capped at ``NUM_INSTANCES``), runs one concurrent
    episode per instance via :func:`run_single_episode`, then aggregates
    per-episode metrics into means, shaped-reward summaries and
    relevant/irrelevant rubric splits.

    Returns:
        Aggregated results dict; episodes that errored are excluded and,
        if none succeed, a zeroed summary with ``num_episodes == 0``.
    """
    print(f"šÆ Evaluating NetHack on {NUM_INSTANCES} {DIFFICULTY} instances...")

    llm = LM(model_name=MODEL_NAME, formatting_model_name=MODEL_NAME, temperature=0.0)

    # Get task instances using the taskset system
    from synth_ai.environments.examples.nethack.taskset import create_nethack_taskset

    taskset = await create_nethack_taskset()

    # Filter for the desired difficulty
    task_instances = [inst for inst in taskset.instances if inst.metadata.difficulty == DIFFICULTY][
        :NUM_INSTANCES
    ]

    if len(task_instances) < NUM_INSTANCES:
        print(f" ā ļø Only found {len(task_instances)} {DIFFICULTY} instances, using all available")

    print(f" š Using {len(task_instances)} {DIFFICULTY} task instances")

    async with AsyncClient(
        base_url=SERVICE_BASE_URL, timeout=60.0
    ) as client:  # Longer timeout for NetHack
        tasks = []
        bars = []
        # One progress bar and one fresh agent per episode; episodes run concurrently.
        for i, task_instance in enumerate(task_instances):
            bar = tqdm(total=MAX_TURNS, desc=f"Ep {i + 1}", position=i, leave=True)
            bars.append(bar)
            agent = NetHackReActAgent(llm, max_turns=MAX_TURNS, verbose=False)
            tasks.append(run_single_episode(client, agent, task_instance, i + 1, bar))

        results = await asyncio.gather(*tasks)

    # Filter out error results
    valid_results = [r for r in results if not r.get("error", False)]

    if not valid_results:
        return {
            "eval_metrics": [],
            "mean_eval_metric": 0.0,
            "mean_rubric": {},
            "num_episodes": 0,
        }

    # Extract eval metrics and rubrics
    eval_metrics = [r["eval_metric"] for r in valid_results]
    mean_eval_metric = sum(eval_metrics) / len(eval_metrics)

    # Extract Balrog scores
    balrog_scores = [r.get("balrog_score", 0.0) for r in valid_results]
    mean_balrog_score = sum(balrog_scores) / len(balrog_scores) if balrog_scores else 0.0

    # Extract Balrog total rewards
    balrog_total_rewards = [r.get("balrog_total_reward", 0.0) for r in valid_results]
    mean_balrog_total_reward = (
        sum(balrog_total_rewards) / len(balrog_total_rewards) if balrog_total_rewards else 0.0
    )

    # Extract additional progress metrics
    progress_metrics = {
        "monsters_killed": [r.get("monsters_killed", 0) for r in valid_results],
        "items_picked_up": [r.get("items_picked_up", 0) for r in valid_results],
        "scrolls_read": [r.get("scrolls_read", 0) for r in valid_results],
        "potions_drunk": [r.get("potions_drunk", 0) for r in valid_results],
        "rooms_explored": [r.get("rooms_explored", 0) for r in valid_results],
        "secret_doors_found": [r.get("secret_doors_found", 0) for r in valid_results],
        "stairs_found": [r.get("stairs_found", 0) for r in valid_results],
        "traps_encountered": [r.get("traps_encountered", 0) for r in valid_results],
        "spells_cast": [r.get("spells_cast", 0) for r in valid_results],
        "prayers_attempted": [r.get("prayers_attempted", 0) for r in valid_results],
        "max_score": [r.get("max_score", 0) for r in valid_results],
    }

    # Calculate means for progress metrics
    mean_progress_metrics = {}
    for key, values in progress_metrics.items():
        mean_progress_metrics[key] = sum(values) / len(values) if values else 0.0

    # Extract shaped rewards, split into relevant vs. "always positive" keys.
    shaped_rewards_summary = {}
    irrelevant_shaped_summary = {}
    if valid_results and "shaped_rewards" in valid_results[0]:
        shaped_reward_keys = valid_results[0]["shaped_rewards"].keys()
        for key in shaped_reward_keys:
            values = [r.get("shaped_rewards", {}).get(key, 0) for r in valid_results]
            if isinstance(values[0], bool):
                avg_value = sum(values) / len(values)  # Fraction of episodes
            else:
                avg_value = sum(values) / len(values) if values else 0.0

            if key in IRRELEVANT_RUBRIC_KEYS:
                irrelevant_shaped_summary[key] = avg_value
            else:
                shaped_rewards_summary[key] = avg_value

    # Calculate individual relevant shaped rewards sums
    individual_relevant_sums = []
    if valid_results and "shaped_rewards" in valid_results[0]:
        for result in valid_results:
            episode_shaped_rewards = result.get("shaped_rewards", {})
            relevant_sum = sum(
                v for k, v in episode_shaped_rewards.items() if k not in IRRELEVANT_RUBRIC_KEYS
            )
            individual_relevant_sums.append(relevant_sum)

    # Calculate mean of relevant shaped rewards sums
    relevant_shaped_rewards_sum = (
        sum(individual_relevant_sums) / len(individual_relevant_sums)
        if individual_relevant_sums
        else 0.0
    )

    # Calculate individual relevant rubric sums
    individual_relevant_rubric_sums = []
    for result in valid_results:
        episode_rubric = result.get("rubric", {})
        relevant_rubric_sum = sum(
            v for k, v in episode_rubric.items() if k not in IRRELEVANT_RUBRIC_KEYS
        )
        individual_relevant_rubric_sums.append(relevant_rubric_sum)

    # Calculate mean of relevant rubric sums
    relevant_rubric_sum = (
        sum(individual_relevant_rubric_sums) / len(individual_relevant_rubric_sums)
        if individual_relevant_rubric_sums
        else 0.0
    )

    # Calculate mean rubric values (excluding irrelevant)
    all_rubric_keys = set()
    for r in valid_results:
        all_rubric_keys.update(
            [k for k in r["rubric"].keys() if k not in IRRELEVANT_RUBRIC_KEYS]
        )

    mean_rubric = {}
    for key in all_rubric_keys:
        # Missing keys default to 0.0 so every episode contributes to the mean.
        values = [r["rubric"].get(key, 0.0) for r in valid_results]
        mean_rubric[key] = sum(values) / len(values)

    # Collect irrelevant rubric metrics summary
    irrelevant_summary = {}
    for key in IRRELEVANT_RUBRIC_KEYS:
        vals = [r.get("irrelevant_rubric", {}).get(key, 0.0) for r in valid_results]
        irrelevant_summary[key] = sum(vals) / len(vals) if vals else 0.0

    return {
        "eval_metrics": eval_metrics,
        "mean_eval_metric": mean_eval_metric,
        "balrog_scores": balrog_scores,
        "mean_balrog_score": mean_balrog_score,
        "balrog_total_rewards": balrog_total_rewards,
        "mean_balrog_total_reward": mean_balrog_total_reward,
        "mean_rubric": mean_rubric,
        "progress_metrics": progress_metrics,
        "mean_progress_metrics": mean_progress_metrics,
        "shaped_rewards_summary": shaped_rewards_summary,
        "irrelevant_summary": irrelevant_summary,
        "irrelevant_shaped_summary": irrelevant_shaped_summary,
        "relevant_shaped_rewards_sum": relevant_shaped_rewards_sum,
        "individual_relevant_sums": individual_relevant_sums,
        "individual_relevant_rubric_sums": individual_relevant_rubric_sums,
        "relevant_rubric_sum": relevant_rubric_sum,
        "num_episodes": len(valid_results),
    }
async def main():
    """Run NetHack evaluation.

    Prints the run configuration, checks the environment service's
    ``/health`` endpoint for NetHack support, runs the batch evaluation,
    and prints a detailed report plus a markdown table row for README
    collation. All failures are reported to stdout; nothing is raised.
    """
    print(f"š® NetHack ReAct Agent Evaluation")
    print(f"Model: {MODEL_NAME}")
    print(f"Service: {SERVICE_BASE_URL}")
    print(f"Instances: {NUM_INSTANCES}")
    print(f"Difficulty: {DIFFICULTY}")
    print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("=" * 50)

    # Test service health
    async with AsyncClient(base_url=SERVICE_BASE_URL, timeout=10.0) as client:
        try:
            health_resp = await client.get("/health")
            health_data = health_resp.json()

            if "NetHack" not in health_data.get("supported_environments", []):
                print("ā NetHack not available on service")
                return

            print("ā Service health check passed")

        except Exception as e:
            print(f"ā Service health check failed: {e}")
            return

    # Run evaluation
    try:
        results = await evaluate_nethack_batch()

        print("\n" + "=" * 80)
        print("š FINAL NETHACK EVALUATION RESULTS")
        print("=" * 80)

        # Print eval metrics
        print(f"š EVAL METRICS:")
        print(f" Episodes: {results['num_episodes']}")
        print(f" Individual Scores: {[f'{x:.2f}' for x in results['eval_metrics']]}")
        print(f" Mean Eval Metric: {results['mean_eval_metric']:.2f}")

        # Print Balrog scores
        print(f"\nāļø BALROG SCORES:")
        print(f" Individual Scores: {[f'{x:.3f}' for x in results['balrog_scores']]}")
        print(f" Mean Balrog Score: {results['mean_balrog_score']:.3f}")

        # Print Balrog total rewards
        print(f"\nš BALROG TOTAL REWARDS:")
        print(f" Individual Rewards: {[f'{x:.2f}' for x in results['balrog_total_rewards']]}")
        print(f" Mean Balrog Total Reward: {results['mean_balrog_total_reward']:.2f}")

        # Print relevant sums
        print(f"\nšÆ RELEVANT RUBRIC SUMS:")
        print(
            f" Individual Sums: {[f'{x:.3f}' for x in results.get('individual_relevant_rubric_sums', [])]}"
        )
        print(f" Mean Relevant Rubric Sum: {results.get('relevant_rubric_sum', 0.0):.3f}")

        print(f"\nšÆ RELEVANT SHAPED REWARD SUMS:")
        print(
            f" Individual Sums: {[f'{x:.3f}' for x in results.get('individual_relevant_sums', [])]}"
        )
        print(
            f" Mean Relevant Shaped Reward Sum: {results.get('relevant_shaped_rewards_sum', 0.0):.3f}"
        )

        # Print rubric results
        print(f"\nšÆ RUBRIC RESULTS:")
        if results["mean_rubric"]:
            for achievement, score in sorted(results["mean_rubric"].items()):
                print(f" {achievement}: {score:.2f}")
        else:
            print(" No rubric data available")

        # Print progress metrics
        print(f"\nš PROGRESS METRICS:")
        if results["mean_progress_metrics"]:
            for metric, value in sorted(results["mean_progress_metrics"].items()):
                print(f" {metric}: {value:.1f}")
        else:
            print(" No progress data available")

        # Print shaped rewards summary
        print(f"\nšÆ SHAPED REWARDS SUMMARY:")
        if results.get("shaped_rewards_summary"):
            for reward_key, value in sorted(results["shaped_rewards_summary"].items()):
                # Booleans are printed verbatim; numbers are formatted.
                if isinstance(value, bool):
                    print(f" {reward_key}: {value}")
                else:
                    print(f" {reward_key}: {value:.3f}")
        else:
            print(" No shaped rewards data available")

        # Print irrelevant shaped rewards
        print(f"\nš« IRRELEVANT SHAPED REWARDS:")
        if results.get("irrelevant_shaped_summary"):
            for reward_key, value in sorted(results["irrelevant_shaped_summary"].items()):
                print(f" {reward_key}: {value:.3f}")
        else:
            print(" None")

        # Print irrelevant rubric metrics
        print(f"\nš« IRRELEVANT RUBRIC METRICS:")
        if results.get("irrelevant_summary"):
            for metric, value in sorted(results["irrelevant_summary"].items()):
                print(f" {metric}: {value:.2f}")
        else:
            print(" None")

        # Overall assessment
        print(f"\nš ASSESSMENT:")
        balrog_score = results["mean_balrog_score"]
        eval_metric = results["mean_eval_metric"]

        if eval_metric > 0.8 or balrog_score > 40.0:
            print("š Excellent performance - mastering the dungeon!")
        elif eval_metric > 0.6 or balrog_score > 20.0:
            print("ā Good performance - making solid progress!")
        elif eval_metric > 0.4 or balrog_score > 10.0:
            print("ā ļø Moderate performance - learning the ropes")
        elif balrog_score > 5.0:
            print("š Decent exploration - building dungeon skills")
        else:
            print("š Early exploration - focus on basic survival and movement")

        # Output markdown table row for README collation
        print(f"\nš MARKDOWN TABLE ROW:")
        print(
            "| Model | Episodes | Mean Eval | Mean Balrog | Mean Relevant Rubric | Mean Relevant Shaped | Non-Zero Progress | Non-Zero Rubric | Assessment |"
        )
        print(
            "|------------------|----------|-----------|-------------|----------------------|----------------------|-------------------|-----------------|------------|"
        )
        relevant_rubric_sum = results.get("relevant_rubric_sum", 0.0)
        relevant_shaped_sum = results.get("relevant_shaped_rewards_sum", 0.0)

        # Count non-zero progress metrics
        progress_metrics = results.get("mean_progress_metrics", {})
        non_zero_progress = sum(1 for value in progress_metrics.values() if value > 0.0)

        # Count non-zero rubric results (excluding irrelevant ones)
        rubric_results = results.get("mean_rubric", {})
        non_zero_rubric = sum(
            1
            for key, value in rubric_results.items()
            if value > 0.0 and key not in IRRELEVANT_RUBRIC_KEYS
        )

        # Coarser bands than the ASSESSMENT print above (table uses 4 levels).
        if eval_metric > 0.6 or balrog_score > 20.0:
            assessment = "Excellent"
        elif eval_metric > 0.4 or balrog_score > 10.0:
            assessment = "Good"
        elif balrog_score > 5.0:
            assessment = "Moderate"
        else:
            assessment = "Learning"

        print(
            f"| {MODEL_NAME:<16} | {results['num_episodes']:>8} | {eval_metric:>9.3f} | {balrog_score:>11.3f} | {relevant_rubric_sum:>20.3f} | {relevant_shaped_sum:>20.3f} | {non_zero_progress:>17} | {non_zero_rubric:>15} | {assessment:<10} |"
        )

    except Exception as e:
        # NOTE(review): failure detail is reduced to str(e); consider
        # traceback.print_exc() here for easier debugging.
        print(f"ā Evaluation failed: {e}")
# Metrics that are considered baseline / always-positive and should be treated as irrelevant when summarizing
# NOTE: defined after the functions that consume it; this is valid because the
# set is only looked up at call time, not at function-definition time.
IRRELEVANT_RUBRIC_KEYS = {
    "survival_reward",
    "hunger_management_reward",
    "damage_avoided_reward",
    "stairs_discovery_reward",
    "turn_alive_total",  # from shaped summary
    "hunger_ok_total",  # from shaped summary
}
# --- CLI Entry Point ---
if __name__ == "__main__":
    import argparse
    import asyncio

    parser = argparse.ArgumentParser(
        description="Run NetHack ReAct Agent Evaluation (TOML configurable)"
    )
    # (long flag, short flag, value type, help text) for every CLI option.
    cli_options = [
        ("--config", "-c", str, "Path to TOML configuration file"),
        ("--model", "-m", str, "Model name (overrides config)"),
        ("--episodes", "-e", int, "Number of episodes (overrides config)"),
        ("--max-turns", "-t", int, "Maximum turns (overrides config)"),
        ("--difficulty", "-d", str, "Difficulty (overrides config)"),
    ]
    for long_flag, short_flag, value_type, help_text in cli_options:
        parser.add_argument(long_flag, short_flag, type=value_type, help=help_text)

    args = parser.parse_args()

    # Load config from the TOML file when given, otherwise use defaults.
    config = NetHackConfig(args.config) if args.config else NetHackConfig()

    # Apply CLI overrides (truthy values only — unset flags default to None).
    if args.model:
        config.model_name = args.model
    if args.episodes:
        config.num_instances = args.episodes
    if args.max_turns:
        config.max_turns = args.max_turns
    if args.difficulty:
        config.difficulty = args.difficulty

    _apply_config_to_globals(config)

    # Kick off the async evaluation.
    asyncio.run(main())