synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.1.9.dist-info/METADATA +0 -37
- synth_ai-0.1.9.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,872 @@
|
|
1
|
+
import asyncio
|
2
|
+
import uuid
|
3
|
+
import pytest
|
4
|
+
import json
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Dict, Any, List, Optional, Deque, Set, Union
|
7
|
+
from pydantic import BaseModel, Field
|
8
|
+
from collections import deque
|
9
|
+
import toml
|
10
|
+
from synth_ai.zyk import LM
|
11
|
+
from synth_ai.zyk.lms.tools.base import BaseTool
|
12
|
+
from synth_sdk.tracing.decorators import trace_event_async
|
13
|
+
from synth_sdk.tracing.trackers import SynthTracker
|
14
|
+
from synth_sdk.tracing.abstractions import RewardSignal, Dataset, TrainingQuestion
|
15
|
+
from synth_sdk.tracing.utils import get_system_id
|
16
|
+
|
17
|
+
# Crafter specific imports
|
18
|
+
from synth_ai.environments.examples.crafter_classic.environment import (
|
19
|
+
CrafterClassicEnvironment,
|
20
|
+
CrafterPublicState,
|
21
|
+
CrafterPrivateState,
|
22
|
+
)
|
23
|
+
from synth_ai.environments.examples.crafter_classic.engine import (
|
24
|
+
CRAFTER_ACTION_MAP, # map of action name to int
|
25
|
+
)
|
26
|
+
|
27
|
+
# Convert CRAFTER_ACTION_MAP to ACTION_STRING_TO_INT and INT_TO_ACTION_STRING
|
28
|
+
ACTION_STRING_TO_INT: Dict[str, int] = CRAFTER_ACTION_MAP
|
29
|
+
INT_TO_ACTION_STRING: Dict[int, str] = {v: k for k, v in CRAFTER_ACTION_MAP.items()}
|
30
|
+
|
31
|
+
|
32
|
+
from synth_ai.environments.environment.shared_engine import (
|
33
|
+
GetObservationCallable,
|
34
|
+
InternalObservation,
|
35
|
+
)
|
36
|
+
from synth_ai.environments.examples.crafter_classic.taskset import (
|
37
|
+
CrafterTaskInstance,
|
38
|
+
CrafterTaskInstanceMetadata,
|
39
|
+
)
|
40
|
+
from synth_ai.environments.tasks.core import Impetus, Intent
|
41
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
42
|
+
|
43
|
+
import logging
|
44
|
+
|
45
|
+
logging.disable(logging.CRITICAL)
|
46
|
+
|
47
|
+
|
48
|
+
# --- Helper to build crafter semantic mapping ---
|
49
|
+
def get_crafter_semantic_mapping():
|
50
|
+
"""Build the crafter semantic ID to item name mapping."""
|
51
|
+
import crafter
|
52
|
+
import itertools
|
53
|
+
|
54
|
+
# Create a dummy env to get ID mappings (same as environment.py)
|
55
|
+
dummyenv = None
|
56
|
+
try:
|
57
|
+
dummyenv = crafter.Env()
|
58
|
+
max_id = (
|
59
|
+
max(
|
60
|
+
max(dummyenv._world._mat_ids.values()),
|
61
|
+
max(dummyenv._sem_view._obj_ids.values()),
|
62
|
+
)
|
63
|
+
+ 1
|
64
|
+
)
|
65
|
+
id_to_item = ["void"] * max_id
|
66
|
+
for name, ind in itertools.chain(
|
67
|
+
dummyenv._world._mat_ids.items(), dummyenv._sem_view._obj_ids.items()
|
68
|
+
):
|
69
|
+
if name is None:
|
70
|
+
clean = "none"
|
71
|
+
elif hasattr(name, "__name__"):
|
72
|
+
clean = name.__name__
|
73
|
+
else:
|
74
|
+
clean = str(name)
|
75
|
+
id_to_item[ind] = clean.lower()
|
76
|
+
player_idx = id_to_item.index("player")
|
77
|
+
return id_to_item, player_idx
|
78
|
+
finally:
|
79
|
+
if dummyenv:
|
80
|
+
try:
|
81
|
+
dummyenv.close()
|
82
|
+
except Exception:
|
83
|
+
pass
|
84
|
+
del dummyenv
|
85
|
+
|
86
|
+
|
87
|
+
# --- Helper function to format observation for LLM ---
|
88
|
+
def format_obs_for_llm_from_states(pub: CrafterPublicState, priv: CrafterPrivateState) -> str:
|
89
|
+
inventory_str = ", ".join(f"{k}:{v}" for k, v in pub.inventory.items() if v > 0)
|
90
|
+
if not inventory_str:
|
91
|
+
inventory_str = "empty"
|
92
|
+
|
93
|
+
achievements_str = ", ".join(k for k, v in pub.achievements_status.items() if v)
|
94
|
+
if not achievements_str:
|
95
|
+
achievements_str = "none"
|
96
|
+
|
97
|
+
# Add map view around player using the real crafter semantic mapping
|
98
|
+
map_view = ""
|
99
|
+
if pub.semantic_map is not None:
|
100
|
+
px, py = pub.player_position
|
101
|
+
view_size = 7 # 7x7 view around player
|
102
|
+
half_view = view_size // 2
|
103
|
+
|
104
|
+
# Get the real crafter semantic mapping
|
105
|
+
id_to_item, player_idx = get_crafter_semantic_mapping()
|
106
|
+
|
107
|
+
# Create a local view around the player using same logic as _plain_grid
|
108
|
+
map_view += f"\nLocal Map View ({view_size}x{view_size} around player):\n"
|
109
|
+
matrix = []
|
110
|
+
for dy in range(-half_view, half_view + 1):
|
111
|
+
row = []
|
112
|
+
for dx in range(-half_view, half_view + 1):
|
113
|
+
x, y = px + dx, py + dy
|
114
|
+
if pub.semantic_map is None or not (
|
115
|
+
0 <= x < pub.semantic_map.shape[0] and 0 <= y < pub.semantic_map.shape[1]
|
116
|
+
):
|
117
|
+
row.append("void")
|
118
|
+
else:
|
119
|
+
idx = pub.semantic_map[x, y]
|
120
|
+
if dx == 0 and dy == 0:
|
121
|
+
row.append("player") # Player position
|
122
|
+
else:
|
123
|
+
# Use the real crafter mapping
|
124
|
+
item_name = id_to_item[idx] if idx < len(id_to_item) else "unknown"
|
125
|
+
row.append(item_name)
|
126
|
+
matrix.append(row)
|
127
|
+
|
128
|
+
# Transpose the matrix like _plain_grid does
|
129
|
+
transposed = list(zip(*matrix))
|
130
|
+
# Convert each row to a space-separated string
|
131
|
+
for row in transposed:
|
132
|
+
map_view += " ".join(row) + "\n"
|
133
|
+
|
134
|
+
# Create a legend of items actually visible in the map
|
135
|
+
visible_items = set()
|
136
|
+
for row in transposed:
|
137
|
+
for item in row:
|
138
|
+
if item not in ["void", "player"]:
|
139
|
+
visible_items.add(item)
|
140
|
+
|
141
|
+
if visible_items:
|
142
|
+
map_view += f"\nVisible items: {', '.join(sorted(visible_items))}"
|
143
|
+
else:
|
144
|
+
map_view += "\nNo special items visible (mostly grass/empty)"
|
145
|
+
|
146
|
+
# Simplified observation, focusing on key elements
|
147
|
+
return (
|
148
|
+
f"Steps: {pub.num_steps_taken}/{pub.max_steps_episode}\n"
|
149
|
+
f"Health: {priv.player_internal_stats.get('health', 'N/A')}\n"
|
150
|
+
f"Inventory: {inventory_str}\n"
|
151
|
+
f"Unlocked Achievements: {achievements_str}\n"
|
152
|
+
f"Player Position: {pub.player_position}\n"
|
153
|
+
f"Last Reward: {priv.reward_last_step:.2f}\n"
|
154
|
+
f"Terminated: {priv.terminated} | Truncated: {priv.truncated}"
|
155
|
+
f"{map_view}"
|
156
|
+
)
|
157
|
+
|
158
|
+
|
159
|
+
# ---------------------------------- custom observation callable (Optional, can be simpler for Crafter) ------------------------------ #
|
160
|
+
# For now, let's assume the default observation from the environment is sufficient,
|
161
|
+
# or we will use the direct public/private states.
|
162
|
+
# If history is needed, we can adapt the Sokoban HistoryObservationCallable.
|
163
|
+
class CrafterHistoryObservationCallable(GetObservationCallable):
|
164
|
+
def __init__(self, max_history: int = 1): # Keep only current obs for simplicity now
|
165
|
+
self._hist_obs: Deque[str] = deque(maxlen=max_history)
|
166
|
+
self._hist_pub_state: Deque[CrafterPublicState] = deque(maxlen=max_history)
|
167
|
+
self._hist_priv_state: Deque[CrafterPrivateState] = deque(maxlen=max_history)
|
168
|
+
|
169
|
+
async def get_observation(
|
170
|
+
self, pub: CrafterPublicState, priv: CrafterPrivateState
|
171
|
+
) -> InternalObservation:
|
172
|
+
if pub is None or priv is None:
|
173
|
+
return {
|
174
|
+
"error": "Missing public or private state in get_observation",
|
175
|
+
"history_formatted_obs": list(self._hist_obs),
|
176
|
+
} # type: ignore[return-value]
|
177
|
+
|
178
|
+
formatted_obs = format_obs_for_llm_from_states(pub, priv)
|
179
|
+
self._hist_obs.append(formatted_obs)
|
180
|
+
self._hist_pub_state.append(pub)
|
181
|
+
self._hist_priv_state.append(priv)
|
182
|
+
|
183
|
+
return {
|
184
|
+
"public": pub,
|
185
|
+
"private": priv,
|
186
|
+
"formatted_obs": formatted_obs, # Current formatted obs
|
187
|
+
"history_formatted_obs": list(self._hist_obs), # History of formatted obs
|
188
|
+
"history_public_states": list(self._hist_pub_state),
|
189
|
+
"history_private_states": list(self._hist_priv_state),
|
190
|
+
} # type: ignore[return-value]
|
191
|
+
|
192
|
+
|
193
|
+
# --- Pydantic Models for Tool Arguments ---
|
194
|
+
class CrafterInteractArgs(BaseModel):
|
195
|
+
actions_list: List[str] = Field(
|
196
|
+
description="A list of action names to execute in sequence in the Crafter environment (e.g., ['move_up', 'move_up', 'place_stone']). Can contain 1-10 actions."
|
197
|
+
)
|
198
|
+
reasoning: str = Field(description="A brief explanation of why these actions were chosen.")
|
199
|
+
|
200
|
+
|
201
|
+
# class TerminateArgs(BaseModel):
|
202
|
+
# reason: str = Field(
|
203
|
+
# description="A detailed reason for why the agent is terminating."
|
204
|
+
# )
|
205
|
+
|
206
|
+
|
207
|
+
# --- ReAct agent for Crafter -------------------------------------------------- #
|
208
|
+
class CrafterInteractTool(BaseTool):
|
209
|
+
"""Tool for interacting with Crafter environment"""
|
210
|
+
|
211
|
+
name: str = "crafter_interact"
|
212
|
+
arguments: type[BaseModel] = CrafterInteractArgs
|
213
|
+
description: str = (
|
214
|
+
"Interacts with the Crafter environment by proposing a sequence of 1-10 actions to execute."
|
215
|
+
)
|
216
|
+
|
217
|
+
|
218
|
+
# class TerminateTool(BaseTool):
|
219
|
+
# """Tool for terminating agent execution"""
|
220
|
+
# name: str = "terminate"
|
221
|
+
# arguments: type[BaseModel] = TerminateArgs
|
222
|
+
# description: str = "Terminates the agent's execution if the task is considered complete or no useful progress can be made."
|
223
|
+
|
224
|
+
|
225
|
+
class CrafterMove(EnvToolCall): # Simple EnvToolCall wrapper
|
226
|
+
def __init__(self, action: int):
|
227
|
+
super().__init__(tool="interact", args={"action": action})
|
228
|
+
|
229
|
+
|
230
|
+
class ReActAgent:
    """ReAct-style agent that plays Crafter via the `crafter_interact` tool.

    The agent keeps a textual transcript of observations / tool calls /
    tool responses, prompts the LLM once per turn, and converts the LLM's
    tool call into a list of integer Crafter action indices. Any response
    that cannot be interpreted degrades to a single noop (action 0) rather
    than crashing the episode.
    """

    def __init__(self, llm, max_turns: int = 50):  # Increased max_turns for Crafter
        self.llm, self.max_turns = llm, max_turns
        # Interleaved entries with "type" in {"obs", "tool_call", "tool_response"}.
        self.history: List[Dict[str, Any]] = []
        self.system_name: str = "crafter-react-ex"  # Changed system name
        self.system_id: Any = get_system_id(self.system_name)
        self.system_instance_id: str = str(uuid.uuid4())
        # Raw observation payload from the most recent decide() call; kept
        # for terminate guardrails.
        self.last_obs_dict: Optional[Dict[str, Any]] = None
        # Unique achievement names unlocked so far this episode.
        self.current_achievements: Set[str] = set()

        self.tools = [
            CrafterInteractTool(),
            # TerminateTool(),  # Commented out to prevent early quitting
        ]

    def _format_history_for_prompt(self) -> str:
        """Render the accumulated history as a plain-text ReAct transcript."""
        prompt_history = []
        for entry in self.history:
            if entry["type"] == "obs":
                prompt_history.append(f"OBSERVATION:\n{entry['content']}")
            elif entry["type"] == "tool_call":
                args_str = json.dumps(entry["tool_arguments"])
                prompt_history.append(
                    f"THOUGHT:\nI will call the tool `{entry['tool_name']}` with arguments: {args_str}\nACTION: (Tool call executed)"
                )
            elif entry["type"] == "tool_response":
                prompt_history.append(
                    "TOOL_RESPONSE:\n(Action executed, new observation will follow if not terminal)"
                )
        return "\n".join(prompt_history)

    def _noop_fallback(self, reason: str) -> List[int]:
        """Record a noop tool call in history and return the noop action [0].

        Centralizes the bookkeeping for every degraded path (no tool calls,
        unparseable arguments) so history entries stay consistent.
        """
        self.history.append(
            {
                "type": "tool_call",
                "tool_name": "noop",
                "tool_arguments": {"reason": reason},
            }
        )
        self.history.append(
            {
                "type": "tool_response",
                "content": "Noop executed due to missing tool calls",
            }
        )
        return [0]  # 'noop' action (action index 0)

    @trace_event_async(event_type="react_agent_decide")
    async def decide(
        self, obs_str: str, current_raw_obs: Dict[str, Any]
    ) -> List[int]:  # Return list of action integers
        """Ask the LLM for the next action sequence.

        Args:
            obs_str: Formatted observation text appended to the transcript.
            current_raw_obs: Raw observation payload; expected to carry a
                "public" CrafterPublicState used to track achievements.

        Returns:
            A non-empty list of Crafter action indices. Falls back to [0]
            (noop) whenever the LLM response cannot be interpreted.
        """
        self.history.append({"type": "obs", "content": obs_str})
        self.last_obs_dict = current_raw_obs  # Store for terminate guardrail

        # Update current achievements from the raw observation
        if current_raw_obs and isinstance(current_raw_obs.get("public"), CrafterPublicState):
            pub_state: CrafterPublicState = current_raw_obs["public"]
            for ach, unlocked in pub_state.achievements_status.items():
                if unlocked:
                    self.current_achievements.add(ach)

        formatted_prompt_history = self._format_history_for_prompt()

        # Updated prompt for Crafter
        prompt = (
            f"{formatted_prompt_history}\n\n"
            "Based on the history above, particularly the last observation (health, inventory, achievements, position), "
            "what is your reasoning and which `crafter_interact` tool should you call next? "
            "Prioritize actions that lead to new achievements or ensure survival (e.g., find food if health is low)."
        )

        system_message = (
            "You are an agent playing Crafter. Your goal is to survive and unlock as many achievements as possible. "
            "Review the history of observations, thoughts, and actions. "
            "Based on this history, particularly the last observation, decide on the best sequence of actions. "
            "You MUST call the available tool: `crafter_interact`.\\n\\n"
            "For `crafter_interact`, provide a list of 1-10 actions to execute in sequence. "
            "Planning ahead with multiple actions is often more efficient than single actions. "
            f"Available actions are: {', '.join(ACTION_STRING_TO_INT.keys())}.\\n"
            "Always provide a `reasoning` field in your tool call."
        )

        # Trace the LLM interaction input so that full messages (system & user) are included in the trace
        SynthTracker.track_lm(
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt},
            ],
            model_name=self.llm.model_name,
            model_params=None,
            finetune=False,
        )

        response_obj = await self.llm.respond_async(
            system_message=system_message, user_message=prompt, tools=self.tools
        )

        # Trace the assistant's reply/output so that it is captured alongside the inputs
        SynthTracker.track_lm_output(
            messages=[{"role": "assistant", "content": response_obj.raw_response}],
            model_name=self.llm.model_name,
            finetune=False,
        )

        tool_calls = response_obj.tool_calls

        # Handle case where tool_calls is None or empty (noop to prevent crash)
        if not tool_calls:
            return self._noop_fallback("no_tool_calls_returned")

        tool_call_data = tool_calls[0]

        # Handle both dict and object formats
        if isinstance(tool_call_data, dict):
            tool_name = tool_call_data["function"]["name"]
            tool_args_str = tool_call_data["function"]["arguments"]
        else:
            tool_name = tool_call_data.function.name
            tool_args_str = tool_call_data.function.arguments

        # BUGFIX: malformed JSON in the LLM's arguments previously raised and
        # killed the episode; degrade to a noop like the no-tool-calls path.
        try:
            tool_arguments = json.loads(tool_args_str)
        except (json.JSONDecodeError, TypeError):
            return self._noop_fallback("unparseable_tool_arguments")

        # Track the tool call details for richer debugging and training signals
        SynthTracker.track_state(
            variable_name="tool_call",
            variable_value={"tool_name": tool_name, "arguments": tool_arguments},
            origin="agent",
        )

        self.history.append(
            {
                "type": "tool_call",
                "tool_name": tool_name,
                "tool_arguments": tool_arguments,
            }
        )
        self.history.append({"type": "tool_response", "content": "Tool executed"})

        if tool_name == "crafter_interact":
            # BUGFIX: a missing or empty actions_list previously raised
            # KeyError / produced an empty plan; default to a single noop.
            actions_list = tool_arguments.get("actions_list") or []
            if not actions_list:
                return [0]

            # Convert action names to integers
            action_ints = []
            for action_str in actions_list:
                if action_str in ACTION_STRING_TO_INT:
                    action_ints.append(ACTION_STRING_TO_INT[action_str])
                else:
                    print(f"[WARNING] Invalid action '{action_str}', using noop instead")
                    action_ints.append(0)  # noop action
            return action_ints

        # BUGFIX: an unrecognized tool name previously fell off the end of the
        # function and returned None, which crashed callers that iterate over
        # the returned action sequence. Return a noop instead.
        # NOTE: terminate-tool handling was removed; TerminateTool is disabled
        # in __init__ to prevent early quitting.
        print(f"[WARNING] Unknown tool '{tool_name}' requested, using noop instead")
        return [0]
+
# --- Test for a single agent run ---
@pytest.mark.asyncio
async def test_react_agent_crafter(tmp_path: Path):
    """End-to-end smoke test: one ReAct agent plays one seeded Crafter episode.

    Builds a reproducible task instance (seed fixed in metadata), runs the
    agent for up to 30 decision turns, and asserts the episode either reached
    a terminal/truncated state or unlocked at least one achievement.
    """
    # For Crafter, the seed in metadata is important for reproducibility.
    # initial_engine_snapshot can be None if the engine handles reset with seed.
    task_metadata = CrafterTaskInstanceMetadata(
        difficulty="easy",
        seed=42,
        # Placeholders; actual values depend on seed and world generation.
        num_trees_radius=0,
        num_cows_radius=0,
        num_hostiles_radius=0,
    )
    inst = CrafterTaskInstance(
        id=uuid.uuid4(),
        impetus=Impetus(instructions="Survive and unlock achievements."),
        intent=Intent(
            rubric={"goal": "Unlock achievements and survive"},
            gold_trajectories=None,
            gold_state_diff={},
        ),
        metadata=task_metadata,
        is_reproducible=True,
        initial_engine_snapshot=None,  # Engine will init with seed from metadata
    )

    hist_cb = CrafterHistoryObservationCallable(max_history=1)
    env = CrafterClassicEnvironment(inst, custom_step_obs=hist_cb)

    llm = LM(model_name="gpt-4.1-nano", formatting_model_name="gpt-4.1-nano", temperature=0.0)
    agent = ReActAgent(llm, max_turns=30)  # Increased for meaningful progress
    print("[DEBUG] Created agent with max_turns=30")

    async def run_episode():
        """Run one episode; return (episode_successful, num_achievements)."""
        obs_payload = await env.initialize()

        if "error" in obs_payload:
            print(f"Error during env.initialize: {obs_payload['error']}")
            return False, 0

        # The CrafterHistoryObservationCallable returns a dict with
        # 'public', 'private' and 'formatted_obs'.
        current_formatted_obs = obs_payload["formatted_obs"]
        # Pass the whole payload, which includes public and private states.
        raw_obs_for_agent_decision = obs_payload

        # BUGFIX: initialize up front instead of relying on the fragile
        # `"obs_payload_next" not in locals()` check after the loop.
        obs_payload_next = obs_payload

        for turn in range(agent.max_turns):
            action_sequence = await agent.decide(current_formatted_obs, raw_obs_for_agent_decision)

            if action_sequence == [-1]:  # Agent decided to terminate
                obs_payload_next = obs_payload  # No new observation if terminated by agent
                break

            # Execute each action in the sequence
            for act_idx in action_sequence:
                obs_payload_next = await env.step([[CrafterMove(act_idx)]])

                if "error" in obs_payload_next:
                    break

                # Update observation for the next action in the sequence
                current_formatted_obs = obs_payload_next["formatted_obs"]
                raw_obs_for_agent_decision = obs_payload_next
                obs_payload = obs_payload_next

                # Stop mid-sequence if the environment terminated
                if obs_payload_next["private"].terminated or obs_payload_next["private"].truncated:
                    break

            if "error" in obs_payload_next:
                break

            # Update observations for the next agent decision
            current_formatted_obs = obs_payload_next["formatted_obs"]
            raw_obs_for_agent_decision = obs_payload_next

            agent.history.append({"type": "tool_response", "content": "Action executed"})

            obs_payload = obs_payload_next

            if obs_payload_next["private"].terminated or obs_payload_next["private"].truncated:
                break

        if "error" in obs_payload_next:
            return False, len(agent.current_achievements)

        # Success here means the environment itself ended the episode
        # (terminated or truncated); achievements are reported separately.
        final_private_state: CrafterPrivateState = obs_payload_next["private"]
        episode_successful = final_private_state.terminated or final_private_state.truncated
        return episode_successful, len(agent.current_achievements)

    episode_completed, num_achievements = await run_episode()

    dataset = Dataset(
        questions=[
            TrainingQuestion(
                id="crafter_ep_test",
                intent="survive and achieve",
                criteria="completed_episode_or_achieved_something",
            )
        ],
        reward_signals=[
            RewardSignal(
                question_id="crafter_ep_test",
                system_instance_id=agent.system_instance_id,
                # Reward if completed or got any achievement.
                reward=1 if episode_completed or num_achievements > 0 else 0,
            )
        ],
    )
    # upload(dataset=dataset)  # Optional: uncomment to upload trace

    assert episode_completed or num_achievements > 0, (
        "Agent failed to complete the episode or unlock any achievement in the test."
    )
|
+
|
528
|
+
async def eval_react_crafter(
    model_name: str = "gpt-4.1-nano",
    formatting_model_name: str = "gpt-4.1-nano",
    modes: Optional[List[str]] = None,
    n_instances_per_mode: int = 3,
) -> List[Dict[str, Any]]:
    """
    Run ReAct agents on Crafter instances of different difficulties,
    and returns a list of dictionaries containing aggregated results for each mode.
    """
    # Import the new evaluation framework
    from synth_ai.environments.examples.crafter_classic.agent_demos.eval_framework import (
        run_crafter_eval,
    )

    if modes is None:
        modes = ["easy", "hard"]

    print("🎯 Running Crafter evaluation with new standardized framework")
    print(f" Model: {model_name}")
    print(f" Modes: {modes}")
    print(f" Trajectories per mode: {n_instances_per_mode}")

    # Delegate the heavy lifting to the comprehensive evaluation framework.
    report = await run_crafter_eval(
        model_names=[model_name],
        difficulties=modes,
        num_trajectories=n_instances_per_mode,
        max_turns=30,
    )

    def _to_legacy_row(agg: Dict[str, Any]) -> Dict[str, Any]:
        """Map one aggregate result onto the legacy tabular row format."""
        wins = int(agg["success_rate"] * agg["num_trajectories"])
        return {
            "Model": agg["model_name"],
            "Difficulty": agg["difficulty"],
            "Successful Runs": f"{wins}/{agg['num_trajectories']}",
            "Avg Unique Achievements": f"{agg['avg_achievements_per_trajectory']:.2f}",
        }

    # Convert to the old format for backward compatibility.
    return [_to_legacy_row(agg) for agg in report["raw_aggregate_results"]]
573
|
+
|
574
|
+
# Keep the old function for backward compatibility
async def eval_react_crafter_legacy(
    model_name: str = "gpt-4.1-nano",
    formatting_model_name: str = "gpt-4.1-nano",
    modes: Optional[List[str]] = None,
    n_instances_per_mode: int = 3,
) -> List[Dict[str, Any]]:
    """
    LEGACY VERSION - Run ReAct agents on Crafter instances of different difficulties,
    and returns a list of dictionaries containing aggregated results for each mode.
    """
    # Hoisted: `import time` was previously re-imported inside the
    # per-difficulty loop on every iteration.
    import time

    if modes is None:
        modes = ["easy", "hard"]

    current_model_name_for_eval = model_name

    _temp_llm_for_names = LM(
        model_name=current_model_name_for_eval,
        formatting_model_name=formatting_model_name,
        temperature=0.0,
    )
    _temp_agent_for_names = ReActAgent(_temp_llm_for_names)
    actual_system_name = (
        _temp_agent_for_names.system_name
    )  # Still useful for logging within this func

    # ------------------------------------------------------------------ helpers
    async def run_episode_eval(
        inst: CrafterTaskInstance, agent_max_turns: int
    ) -> tuple[bool, int, list[str], int]:
        """Run single episode and return (success, num_achievements, achievements_list, steps_taken)"""
        llm = LM(
            model_name=current_model_name_for_eval,
            formatting_model_name=current_model_name_for_eval,
            temperature=0.0,
        )

        hist_cb = CrafterHistoryObservationCallable(max_history=1)
        env = CrafterClassicEnvironment(inst, custom_step_obs=hist_cb)
        agent = ReActAgent(llm, max_turns=agent_max_turns)

        obs_payload = await env.initialize()
        if "error" in obs_payload:
            return False, 0, [], 0

        current_formatted_obs = obs_payload["formatted_obs"]
        raw_obs_for_agent_decision = obs_payload

        for turn_idx in range(agent.max_turns):
            action_sequence = await agent.decide(current_formatted_obs, raw_obs_for_agent_decision)

            if action_sequence == [-1]:  # agent terminated
                break

            # Execute each action in the sequence
            for act_idx in action_sequence:
                obs_payload_next = await env.step([[CrafterMove(act_idx)]])

                if "error" in obs_payload_next:
                    break  # Break out of action sequence on error

                # Update observation for the next action in the sequence
                current_formatted_obs = obs_payload_next["formatted_obs"]
                raw_obs_for_agent_decision = obs_payload_next

                # Stop mid-sequence if the environment terminated
                if obs_payload_next["private"].terminated or obs_payload_next["private"].truncated:
                    break

            if "error" in obs_payload_next:
                return (
                    False,
                    len(agent.current_achievements),
                    list(agent.current_achievements),
                    0,
                )

            current_formatted_obs = obs_payload_next["formatted_obs"]
            raw_obs_for_agent_decision = obs_payload_next
            agent.history.append({"type": "tool_response", "content": "Action executed"})

            obs_payload = obs_payload_next
            if obs_payload["private"].terminated or obs_payload["private"].truncated:
                break

        final_private_state: CrafterPrivateState = obs_payload["private"]
        final_public_state: CrafterPublicState = obs_payload["public"]

        # A run counts as successful if the env ended the episode or the
        # agent unlocked at least one achievement.
        run_successful = (final_private_state.terminated or final_private_state.truncated) or len(
            agent.current_achievements
        ) >= 1

        num_unique_achievements = len(agent.current_achievements)
        achievements_list = list(agent.current_achievements)
        steps_taken = final_public_state.num_steps_taken

        return run_successful, num_unique_achievements, achievements_list, steps_taken

    # ---------------------------------------------------------------- instance factory
    async def make_crafter_instances(
        difficulty: str, n_instances: int = 3, start_seed: int = 0
    ) -> List[CrafterTaskInstance]:
        """Build n reproducible task instances with consecutive seeds."""
        instances = []
        for i in range(n_instances):
            current_seed = start_seed + i
            metadata = CrafterTaskInstanceMetadata(
                difficulty=difficulty,
                seed=current_seed,
                num_trees_radius=0,
                num_cows_radius=0,
                num_hostiles_radius=0,
            )
            instance = CrafterTaskInstance(
                id=uuid.uuid4(),
                impetus=Impetus(
                    instructions=f"Survive and unlock achievements in a {difficulty} environment."
                ),
                intent=Intent(rubric={}, gold_trajectories=None, gold_state_diff={}),
                metadata=metadata,
                is_reproducible=True,
                initial_engine_snapshot=None,
            )
            instances.append(instance)
        return instances

    # ---------------------------------------------------------------- evaluation
    configs = []
    for mode in modes:
        if mode == "easy":
            configs.append(("easy", n_instances_per_mode, 15))
        elif mode == "hard":
            configs.append(("hard", n_instances_per_mode, 15))

    results_for_model = []  # Stores dicts for each mode for the current model
    base_seed_for_difficulty = {"easy": 1000, "hard": 2000}

    print(
        f"Starting Crafter ReAct Agent Evaluation for Model: {current_model_name_for_eval}, System: {actual_system_name}"
    )

    all_generated_task_data = []
    print("\nGenerating task instances...")
    all_tasks_for_eval: Dict[str, List[CrafterTaskInstance]] = {}
    for label, num_agents, _ in configs:
        insts = await make_crafter_instances(
            label, n_instances=num_agents, start_seed=base_seed_for_difficulty[label]
        )
        all_tasks_for_eval[label] = insts
        for inst in insts:
            instance_dict = await inst.serialize()
            all_generated_task_data.append(instance_dict)
        print(f"Generated {len(insts)} instances for {label} difficulty.")

    # Persist the generated instances so runs are reproducible/auditable.
    dataset_dir = Path(__file__).parent.parent / "dataset"
    dataset_dir.mkdir(parents=True, exist_ok=True)
    synthetic_mix_path = dataset_dir / "synthetic_mix.json"
    with open(synthetic_mix_path, "w") as f:
        json.dump(all_generated_task_data, f, indent=2)
    print(
        f"Saved all {len(all_generated_task_data)} generated task instances to {synthetic_mix_path}"
    )

    for label, num_agents, max_episode_turns in configs:
        print(
            f"\nRunning {num_agents} agents on {label} difficulty tasks (max_turns: {max_episode_turns}) for model {current_model_name_for_eval}..."
        )
        current_difficulty_instances = all_tasks_for_eval[label]

        start_time = time.time()
        results_per_episode = await asyncio.gather(
            *(run_episode_eval(inst, max_episode_turns) for inst in current_difficulty_instances)
        )
        end_time = time.time()
        print(
            f"Completed {len(current_difficulty_instances)} episodes in {end_time - start_time:.1f}s"
        )

        # Process detailed results
        successful_episodes = 0
        total_achievements = 0
        detailed_results = []

        for i, (success, num_achievements, achievements_list, steps_taken) in enumerate(
            results_per_episode
        ):
            episode_result = {
                "episode_id": i + 1,
                "instance_id": current_difficulty_instances[i].id,
                "success": success,
                "achievements_count": num_achievements,
                "achievements": achievements_list,
                "steps_taken": steps_taken,
                "turns_used": "unknown",  # Could track this if needed
            }
            detailed_results.append(episode_result)

            if success:
                successful_episodes += 1
            total_achievements += num_achievements

        avg_achievements = total_achievements / len(results_per_episode)

        # Print detailed trajectory information
        print(f"\n📊 Detailed Results for {model_name} on {label}:")
        print("-" * 80)
        for result in detailed_results:
            status = "✅ SUCCESS" if result["success"] else "❌ FAILED"
            achievements_str = (
                ", ".join(result["achievements"]) if result["achievements"] else "None"
            )
            print(
                f"Episode {result['episode_id']}: {status} | "
                f"Steps: {result['steps_taken']} | "
                f"Achievements ({result['achievements_count']}): {achievements_str}"
            )
        print("-" * 80)

        print(
            f"Completed {label} for model {model_name}: {successful_episodes}/{len(results_per_episode)} successful, Avg. Achievements: {avg_achievements:.2f}"
        )

        results_for_model.append(
            {
                "Model": model_name,
                "Difficulty": label,
                "Successful Runs": f"{successful_episodes}/{len(results_per_episode)}",
                "Avg Unique Achievements": f"{avg_achievements:.2f}",
            }
        )

    return results_for_model
|
814
|
+
async def run_model_comparison_from_config():
    """Run model comparison using parameters from eval_config.toml"""
    # Locate and load the evaluation configuration next to this script.
    cfg_file = Path(__file__).parent / "eval_config.toml"
    if not cfg_file.exists():
        raise FileNotFoundError(f"Configuration file not found: {cfg_file}")

    eval_config = toml.load(cfg_file)["evaluation"]
    models = eval_config["models"]
    difficulties = eval_config["difficulties"]
    max_turns = eval_config["max_turns"]
    n_trajectories = eval_config["trajectories_per_condition"]

    # Update global max_turns from config
    # NOTE(review): eval_react_crafter_legacy takes its parameters explicitly,
    # so this global appears to have no effect there — kept for compatibility.
    global agent_max_turns
    agent_max_turns = max_turns

    banner = "=" * 50
    print("🎯 Crafter Multi-Action Model Comparison")
    print(banner)
    print(f"Models: {', '.join(models)}")
    print(f"Difficulties: {', '.join(difficulties)}")
    print(f"Max turns: {max_turns}")
    print(f"Trajectories per condition: {n_trajectories}")
    print(banner)

    all_results = []

    for model_name in models:
        print(f"\n🤖 Running {model_name}...")

        # Update the global variable for the model
        # NOTE(review): the legacy function rebinds this name locally, so the
        # global is likewise informational only — kept for compatibility.
        global current_model_name_for_eval
        current_model_name_for_eval = model_name

        model_results = await eval_react_crafter_legacy(
            model_name=model_name,
            formatting_model_name=model_name,
            modes=difficulties,
            n_instances_per_mode=n_trajectories,
        )

        all_results.extend(model_results)
        print(f"✅ {model_name} completed")

    print("\n" + "=" * 60)
    print("📊 FINAL COMPARISON RESULTS")
    print("=" * 60)

    from tabulate import tabulate

    print(tabulate(all_results, headers="keys", tablefmt="github"))

    return all_results
|
871
|
+
if __name__ == "__main__":
    # Script entry point: run the config-driven model comparison end to end.
    asyncio.run(run_model_comparison_from_config())
|