synth-ai 0.2.4.dev7__py3-none-any.whl → 0.2.4.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +1 -1
- synth_ai/cli/__init__.py +6 -0
- synth_ai/cli/balance.py +3 -15
- synth_ai/cli/demo.py +68 -9
- synth_ai/cli/rl_demo.py +137 -0
- synth_ai/cli/root.py +65 -0
- synth_ai/config/base_url.py +47 -0
- synth_ai/demos/core/__init__.py +1 -0
- synth_ai/demos/core/cli.py +621 -0
- synth_ai/demos/demo_task_apps/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/core.py +374 -0
- synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/math/app.py +37 -0
- synth_ai/demos/demo_task_apps/math/config.toml +44 -0
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
- synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
- synth_ai/environments/examples/bandit/__init__.py +33 -0
- synth_ai/environments/examples/bandit/engine.py +294 -0
- synth_ai/environments/examples/bandit/environment.py +194 -0
- synth_ai/environments/examples/bandit/taskset.py +200 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
- synth_ai/environments/examples/crafter_classic/environment.py +41 -2
- synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
- synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
- synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
- synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
- synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/http.py +102 -0
- synth_ai/inference/__init__.py +7 -0
- synth_ai/inference/client.py +20 -0
- synth_ai/install_sqld.sh +40 -0
- synth_ai/jobs/client.py +246 -0
- synth_ai/learning/__init__.py +24 -0
- synth_ai/learning/client.py +149 -0
- synth_ai/learning/config.py +43 -0
- synth_ai/learning/constants.py +29 -0
- synth_ai/learning/ft_client.py +59 -0
- synth_ai/learning/health.py +43 -0
- synth_ai/learning/jobs.py +205 -0
- synth_ai/learning/rl_client.py +256 -0
- synth_ai/learning/sse.py +58 -0
- synth_ai/learning/validators.py +48 -0
- synth_ai/lm/core/main_v3.py +13 -0
- synth_ai/lm/core/synth_models.py +48 -0
- synth_ai/lm/core/vendor_clients.py +9 -6
- synth_ai/lm/vendors/core/openai_api.py +31 -3
- synth_ai/lm/vendors/openai_standard.py +45 -14
- synth_ai/lm/vendors/supported/custom_endpoint.py +12 -2
- synth_ai/lm/vendors/synth_client.py +372 -28
- synth_ai/rl/__init__.py +30 -0
- synth_ai/rl/contracts.py +32 -0
- synth_ai/rl/env_keys.py +137 -0
- synth_ai/rl/secrets.py +19 -0
- synth_ai/scripts/verify_rewards.py +100 -0
- synth_ai/task/__init__.py +10 -0
- synth_ai/task/contracts.py +120 -0
- synth_ai/task/health.py +28 -0
- synth_ai/task/validators.py +12 -0
- synth_ai/tracing_v3/hooks.py +3 -1
- synth_ai/tracing_v3/session_tracer.py +123 -2
- synth_ai/tracing_v3/turso/manager.py +218 -0
- synth_ai/tracing_v3/turso/models.py +53 -0
- synth_ai-0.2.4.dev9.dist-info/METADATA +91 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/RECORD +147 -30
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/entry_points.txt +1 -0
- synth_ai/tui/__init__.py +0 -1
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -340
- synth_ai-0.2.4.dev7.dist-info/METADATA +0 -193
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,858 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Comprehensive script to run Crafter rollouts for multiple models and compare their performance.
|
|
4
|
+
Updated to use tracing_v3 with async architecture.
|
|
5
|
+
|
|
6
|
+
Runs experiments for:
|
|
7
|
+
- gpt-4o-mini
|
|
8
|
+
- gpt-4.1-mini
|
|
9
|
+
- gpt-4.1-nano
|
|
10
|
+
- gemini-1.5-flash
|
|
11
|
+
- gemini-2.5-flash-lite
|
|
12
|
+
- qwen3/32b
|
|
13
|
+
|
|
14
|
+
Analyzes and compares:
|
|
15
|
+
- Invalid action rates
|
|
16
|
+
- Achievement frequencies by step
|
|
17
|
+
- Achievement counts across models
|
|
18
|
+
- Performance metrics
|
|
19
|
+
- Cost analysis
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import argparse
|
|
23
|
+
import asyncio
|
|
24
|
+
import json
|
|
25
|
+
import logging
|
|
26
|
+
import os
|
|
27
|
+
import sys
|
|
28
|
+
import time
|
|
29
|
+
from collections import defaultdict
|
|
30
|
+
from datetime import datetime
|
|
31
|
+
from pathlib import Path
|
|
32
|
+
from typing import Any
|
|
33
|
+
from uuid import uuid4
|
|
34
|
+
|
|
35
|
+
import numpy as np
|
|
36
|
+
import pandas as pd
|
|
37
|
+
from tqdm import tqdm
|
|
38
|
+
from tqdm.asyncio import tqdm_asyncio as atqdm
|
|
39
|
+
|
|
40
|
+
# Disable httpx logging for cleaner output
|
|
41
|
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
|
42
|
+
|
|
43
|
+
# Add parent directory to path for imports
|
|
44
|
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent.parent))
|
|
45
|
+
|
|
46
|
+
# Disable v1 logging to see v3 tracing clearly
|
|
47
|
+
os.environ["LANGFUSE_ENABLED"] = "false"
|
|
48
|
+
os.environ["SYNTH_LOGGING"] = "false"
|
|
49
|
+
|
|
50
|
+
# Import enhanced LM with v3 tracing
|
|
51
|
+
from synth_ai.lm.core.main_v3 import LM
|
|
52
|
+
from synth_ai.tracing_v3.abstractions import (
|
|
53
|
+
EnvironmentEvent,
|
|
54
|
+
RuntimeEvent,
|
|
55
|
+
SessionEventMarkovBlanketMessage,
|
|
56
|
+
TimeRecord,
|
|
57
|
+
)
|
|
58
|
+
from synth_ai.tracing_v3.decorators import set_turn_number
|
|
59
|
+
|
|
60
|
+
# Import session tracer for v3 tracing
|
|
61
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
62
|
+
|
|
63
|
+
# from synth_ai.tracing_v3.utils import create_experiment_context # Not needed
|
|
64
|
+
from synth_ai.tracing_v3.turso.manager import AsyncSQLTraceManager
|
|
65
|
+
|
|
66
|
+
# Import Crafter hooks
|
|
67
|
+
try:
|
|
68
|
+
from synth_ai.environments.examples.crafter_classic.trace_hooks_v3 import CRAFTER_HOOKS
|
|
69
|
+
print(f"ā
Loaded {len(CRAFTER_HOOKS.hooks)} Crafter achievement hooks (Easy, Medium, Hard)")
|
|
70
|
+
except ImportError:
|
|
71
|
+
print("Warning: Could not import CRAFTER_HOOKS for v3")
|
|
72
|
+
from synth_ai.tracing_v3.hooks import HookManager
|
|
73
|
+
CRAFTER_HOOKS = HookManager()
|
|
74
|
+
|
|
75
|
+
import random
|
|
76
|
+
|
|
77
|
+
import httpx
|
|
78
|
+
|
|
79
|
+
# Global buckets for sessions
|
|
80
|
+
_SESSIONS: dict[str, tuple[str, object]] = {} # session_id -> (experiment_id, trace)
|
|
81
|
+
|
|
82
|
+
# Configuration
|
|
83
|
+
MODELS_TO_TEST = [
|
|
84
|
+
"gpt-4o-mini",
|
|
85
|
+
"gpt-4.1-mini",
|
|
86
|
+
]
|
|
87
|
+
|
|
88
|
+
# Service URLs (modify these based on your setup)
|
|
89
|
+
CRAFTER_SERVICE_URL = "http://localhost:8901"
|
|
90
|
+
|
|
91
|
+
# Database configuration - uses the centralized config which matches serve.sh
|
|
92
|
+
from synth_ai.tracing_v3.db_config import get_default_db_config
|
|
93
|
+
|
|
94
|
+
db_config = get_default_db_config()
|
|
95
|
+
DATABASE_URL = db_config.database_url
|
|
96
|
+
|
|
97
|
+
# Retry configuration for HTTP requests
|
|
98
|
+
MAX_RETRIES = 3
|
|
99
|
+
BASE_DELAY = 0.1
|
|
100
|
+
MAX_DELAY = 2.0
|
|
101
|
+
HTTP_TIMEOUT = 30.0
|
|
102
|
+
|
|
103
|
+
class ExperimentConfig:
    """Configuration for the multi-model experiment.

    Every setting keeps its previous default, so ``ExperimentConfig()`` behaves
    exactly as before. Any setting can now also be overridden at construction
    time by keyword instead of mutating the instance afterwards.
    """

    def __init__(
        self,
        *,
        num_episodes: int = 10,
        max_turns: int = 100,
        difficulty: str = "easy",
        save_traces: bool = True,
        verbose: bool = True,
        quiet: bool = False,
        enable_v3_tracing: bool = True,
        v3_trace_dir: str = "./traces",
        crafter_service_url: "str | None" = None,
        database_url: "str | None" = None,
        base_seed: int = 1000,
        turn_timeout: float = 30.0,
        episode_timeout: float = 300.0,
    ):
        self.num_episodes = num_episodes  # Number of episodes per model
        self.max_turns = max_turns  # Max turns per episode
        self.difficulty = difficulty
        self.save_traces = save_traces
        self.verbose = verbose
        self.quiet = quiet  # Default to verbose mode
        self.enable_v3_tracing = enable_v3_tracing
        self.v3_trace_dir = v3_trace_dir
        # None means "use the module-level default"; the module constants are
        # only read when no explicit value was supplied.
        self.crafter_service_url = (
            crafter_service_url if crafter_service_url is not None else CRAFTER_SERVICE_URL
        )
        self.database_url = database_url if database_url is not None else DATABASE_URL
        self.base_seed = base_seed  # Base seed for episode generation
        self.turn_timeout = turn_timeout  # Timeout per turn in seconds
        self.episode_timeout = episode_timeout  # Total timeout per episode in seconds
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
async def retry_http_request(client: httpx.AsyncClient, method: str, url: str, **kwargs) -> Any:
    """Retry HTTP requests with exponential backoff and jitter.

    Makes up to ``MAX_RETRIES`` attempts. A response with status < 500 is
    returned immediately (including 4xx, which the caller must inspect).
    5xx responses and any exception raised while sending are retried.

    Args:
        client: The async HTTP client used to send the request.
        method: HTTP method name passed to ``client.request``.
        url: Target URL.
        **kwargs: Extra keyword arguments forwarded to ``client.request``.
            A ``timeout`` given here overrides the ``HTTP_TIMEOUT`` default.

    Returns:
        The first response whose status code is below 500.

    Raises:
        Exception: The last failure seen once all attempts are exhausted.
            Connection failures are wrapped in a plain ``Exception`` whose
            ``__cause__`` is the original ``httpx.ConnectError``.
    """
    # Use the module default only when the caller did not pass a timeout;
    # passing both used to raise a duplicate-keyword TypeError on every attempt.
    kwargs.setdefault("timeout", HTTP_TIMEOUT)
    last_exception = None

    for attempt in range(MAX_RETRIES):
        try:
            if attempt > 0:
                # Exponential backoff capped at MAX_DELAY, plus up to 10% jitter
                delay = min(BASE_DELAY * (2 ** (attempt - 1)), MAX_DELAY)
                jitter = random.uniform(0, 0.1 * delay)
                total_delay = delay + jitter
                await asyncio.sleep(total_delay)

            response = await client.request(method, url, **kwargs)

            if response.status_code < 500:
                return response

            # Server error: remember it and fall through to the next attempt
            last_exception = Exception(f"HTTP {response.status_code}: {response.text}")

        except httpx.ConnectError as e:
            wrapped = Exception(f"Connection failed to {url}: {e}")
            wrapped.__cause__ = e  # keep the original error reachable for debugging
            last_exception = wrapped
            if attempt < MAX_RETRIES - 1:
                # Extra wait for connection failures; this is in addition to the
                # backoff applied at the top of the next attempt.
                await asyncio.sleep(1.0 * (2 ** attempt))
        except httpx.ReadError as e:
            last_exception = e
            if attempt < MAX_RETRIES - 1:
                # Extra wait for read errors, capped at 5 seconds
                read_error_delay = min(1.0 * (2 ** attempt), 5.0)
                await asyncio.sleep(read_error_delay)
        except Exception as e:
            last_exception = e

    if last_exception is None:
        # Only reachable when MAX_RETRIES <= 0 and no attempt was made;
        # without this the final ``raise`` would try to raise None.
        last_exception = RuntimeError(
            f"No attempts made for {method} {url} (MAX_RETRIES={MAX_RETRIES})"
        )

    print(f" ā HTTP request failed after {MAX_RETRIES} attempts: {method} {url}")
    print(f" ā Error: {type(last_exception).__name__}: {str(last_exception)[:200]}")
    raise last_exception
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# Crafter action mapping
|
|
159
|
+
# Action names listed in ID order: an action's position in the tuple is its ID.
CRAFTER_ACTIONS = {
    action_name: action_id
    for action_id, action_name in enumerate((
        "noop",
        "move_left",
        "move_right",
        "move_up",
        "move_down",
        "do",
        "sleep",
        "place_stone",
        "place_table",
        "place_furnace",
        "place_plant",
        "make_wood_pickaxe",
        "make_stone_pickaxe",
        "make_iron_pickaxe",
        "make_wood_sword",
        "make_stone_sword",
        "make_iron_sword",
        "eat_cow",
        "eat_plant",
    ))
}

# Reverse lookup (action ID -> action name), used to validate action IDs
INT_TO_ACTION_STRING = dict(zip(CRAFTER_ACTIONS.values(), CRAFTER_ACTIONS.keys()))
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def compress_observation_for_trace(obs: dict[str, Any]) -> str:
    """Compress observation data for storage in traces.

    The summary keeps only:
      - "inv": inventory entries whose count is greater than 0
      - "nearby": the observation's "nearby" value (default ``[]``)
      - "hp" / "food": ``status["health"]`` / ``status["food"]`` (default 0)
      - "ach": the number of truthy entries in "achievements_status"

    Args:
        obs: Observation mapping; missing keys fall back to the defaults above.

    Returns:
        A compact JSON string (no whitespace between separators). If the
        summary cannot be built or serialized, a JSON object of the form
        ``{"error": "<message>"}`` is returned instead of raising.
    """
    try:
        return json.dumps({
            "inv": {k: v for k, v in obs.get("inventory", {}).items() if v > 0},
            "nearby": obs.get("nearby", []),
            "hp": obs.get("status", {}).get("health", 0),
            "food": obs.get("status", {}).get("food", 0),
            "ach": sum(1 for v in obs.get("achievements_status", {}).values() if v)
        }, separators=(',', ':'))
    except Exception as e:
        # Best-effort by design: tracing must not crash the episode. Serialize
        # the message with json.dumps so quotes/backslashes in it are escaped
        # and the result is always valid JSON.
        return json.dumps({"error": str(e)}, ensure_ascii=False)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def create_message(content: str, message_type: str, system_id: str, turn: int) -> SessionEventMarkovBlanketMessage:
    """Build a SessionEventMarkovBlanketMessage tagged with its system and turn.

    The metadata carries ``system_id`` and ``turn``; the time record uses the
    current wall-clock time as ``event_time`` and the turn number as
    ``message_time``.
    """
    message_metadata = {"system_id": system_id, "turn": turn}
    timestamps = TimeRecord(event_time=time.time(), message_time=turn)
    return SessionEventMarkovBlanketMessage(
        content=content,
        message_type=message_type,
        metadata=message_metadata,
        time_record=timestamps,
    )
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
async def run_episode(config: ExperimentConfig,
|
|
199
|
+
model_name: str,
|
|
200
|
+
episode_num: int,
|
|
201
|
+
experiment_id: str) -> dict[str, Any]:
|
|
202
|
+
"""Run a single episode with a specific model using v3 tracing."""
|
|
203
|
+
# Create a new session tracer for this episode
|
|
204
|
+
session_tracer = SessionTracer(hooks=CRAFTER_HOOKS, db_url=config.database_url)
|
|
205
|
+
|
|
206
|
+
# Start session with metadata
|
|
207
|
+
session_id = await session_tracer.start_session(
|
|
208
|
+
metadata={
|
|
209
|
+
"model": model_name,
|
|
210
|
+
"episode": episode_num,
|
|
211
|
+
"experiment_id": experiment_id,
|
|
212
|
+
"difficulty": config.difficulty
|
|
213
|
+
}
|
|
214
|
+
)
|
|
215
|
+
|
|
216
|
+
# Started tracing session (output disabled for clean UI)
|
|
217
|
+
|
|
218
|
+
# Store session in global bucket
|
|
219
|
+
_SESSIONS[session_id] = (experiment_id, session_tracer)
|
|
220
|
+
|
|
221
|
+
# Initialize LM with session tracer
|
|
222
|
+
lm = LM(
|
|
223
|
+
vendor="openai",
|
|
224
|
+
model=model_name,
|
|
225
|
+
temperature=0.1, # Low temperature for more consistent gameplay
|
|
226
|
+
session_tracer=session_tracer,
|
|
227
|
+
system_id=f"crafter_agent_{model_name}",
|
|
228
|
+
enable_v3_tracing=True
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
# Create HTTP client
|
|
232
|
+
async with httpx.AsyncClient() as client:
|
|
233
|
+
try:
|
|
234
|
+
# Initialize environment with consecutive seed
|
|
235
|
+
seed = config.base_seed + episode_num # Base seed + episode number for consecutive seeds
|
|
236
|
+
request_data = {"config": {"difficulty": config.difficulty, "seed": seed}}
|
|
237
|
+
init_response = await retry_http_request(
|
|
238
|
+
client, "POST", f"{config.crafter_service_url}/env/CrafterClassic/initialize",
|
|
239
|
+
json=request_data
|
|
240
|
+
)
|
|
241
|
+
init_data = init_response.json()
|
|
242
|
+
|
|
243
|
+
# Debug the response format (removed for clean output)
|
|
244
|
+
|
|
245
|
+
# Handle different possible response formats
|
|
246
|
+
if "instance_id" in init_data:
|
|
247
|
+
instance_id = init_data["instance_id"]
|
|
248
|
+
elif "env_id" in init_data:
|
|
249
|
+
instance_id = init_data["env_id"]
|
|
250
|
+
elif "id" in init_data:
|
|
251
|
+
instance_id = init_data["id"]
|
|
252
|
+
else:
|
|
253
|
+
# If none of the expected keys exist, print the response and raise a clear error
|
|
254
|
+
print(f"ā Unexpected response format from Crafter service: {init_data}")
|
|
255
|
+
raise KeyError(f"Could not find environment ID in response. Available keys: {list(init_data.keys())}")
|
|
256
|
+
|
|
257
|
+
# Get initial observation (from initialize response)
|
|
258
|
+
obs = init_data["observation"]
|
|
259
|
+
|
|
260
|
+
prev_obs = obs
|
|
261
|
+
done = False
|
|
262
|
+
invalid_actions = 0
|
|
263
|
+
total_actions = 0
|
|
264
|
+
episode_start_time = time.time()
|
|
265
|
+
|
|
266
|
+
for turn in range(config.max_turns):
|
|
267
|
+
if done:
|
|
268
|
+
break
|
|
269
|
+
|
|
270
|
+
# Check episode timeout
|
|
271
|
+
if time.time() - episode_start_time > config.episode_timeout:
|
|
272
|
+
print(f" ā° Episode {episode_num} timed out after {config.episode_timeout}s")
|
|
273
|
+
done = True
|
|
274
|
+
break
|
|
275
|
+
|
|
276
|
+
# Update progress bar
|
|
277
|
+
if hasattr(config, '_pbar'):
|
|
278
|
+
current_achievements = sum(1 for v in obs.get("achievements_status", {}).values() if v)
|
|
279
|
+
config._pbar.set_postfix({
|
|
280
|
+
f"ep{episode_num}": f"step {turn+1}/{config.max_turns}, ach: {current_achievements}"
|
|
281
|
+
})
|
|
282
|
+
|
|
283
|
+
set_turn_number(turn)
|
|
284
|
+
|
|
285
|
+
# Start timestep for this turn
|
|
286
|
+
await session_tracer.start_timestep(f"turn_{turn}")
|
|
287
|
+
|
|
288
|
+
# Prepare context for the agent
|
|
289
|
+
inventory_str = ", ".join([f"{k}: {v}" for k, v in obs.get("inventory", {}).items() if v > 0])
|
|
290
|
+
if not inventory_str:
|
|
291
|
+
inventory_str = "empty"
|
|
292
|
+
|
|
293
|
+
nearby_str = ", ".join(obs.get("nearby", []))
|
|
294
|
+
if not nearby_str:
|
|
295
|
+
nearby_str = "nothing"
|
|
296
|
+
|
|
297
|
+
status = obs.get("status", {})
|
|
298
|
+
health = status.get("health", 0)
|
|
299
|
+
hunger = status.get("food", 0)
|
|
300
|
+
|
|
301
|
+
# Get more detailed game state
|
|
302
|
+
position = obs.get("position", [0, 0])
|
|
303
|
+
achievements = obs.get("achievements_status", {})
|
|
304
|
+
unlocked = [name for name, status in achievements.items() if status]
|
|
305
|
+
achievements_str = ", ".join(unlocked) if unlocked else "none"
|
|
306
|
+
|
|
307
|
+
# Get semantic map if available
|
|
308
|
+
semantic_map = obs.get("semantic_map", None)
|
|
309
|
+
map_str = ""
|
|
310
|
+
if semantic_map is not None:
|
|
311
|
+
# Simple 5x5 view around player
|
|
312
|
+
try:
|
|
313
|
+
px, py = position
|
|
314
|
+
view_size = 5
|
|
315
|
+
half = view_size // 2
|
|
316
|
+
map_lines = []
|
|
317
|
+
for dy in range(-half, half + 1):
|
|
318
|
+
row = []
|
|
319
|
+
for dx in range(-half, half + 1):
|
|
320
|
+
x, y = px + dx, py + dy
|
|
321
|
+
if dx == 0 and dy == 0:
|
|
322
|
+
row.append("@") # Player
|
|
323
|
+
elif 0 <= x < len(semantic_map) and 0 <= y < len(semantic_map[0]):
|
|
324
|
+
cell = semantic_map[x][y]
|
|
325
|
+
# Map common items
|
|
326
|
+
if cell == 0:
|
|
327
|
+
row.append(".") # Empty/grass
|
|
328
|
+
elif cell == 1:
|
|
329
|
+
row.append("T") # Tree
|
|
330
|
+
elif cell == 2:
|
|
331
|
+
row.append("S") # Stone
|
|
332
|
+
elif cell == 3:
|
|
333
|
+
row.append("C") # Cow
|
|
334
|
+
elif cell == 4:
|
|
335
|
+
row.append("W") # Water
|
|
336
|
+
else:
|
|
337
|
+
row.append("?")
|
|
338
|
+
else:
|
|
339
|
+
row.append("#") # Out of bounds
|
|
340
|
+
map_lines.append(" ".join(row))
|
|
341
|
+
map_str = "\nMap (5x5 view, @ = you):\n" + "\n".join(map_lines)
|
|
342
|
+
except Exception:
|
|
343
|
+
map_str = "\nMap view unavailable"
|
|
344
|
+
|
|
345
|
+
# Create agent prompt
|
|
346
|
+
prompt = f"""Game State (Turn {turn}):
|
|
347
|
+
- Position: {position}
|
|
348
|
+
- Health: {health}/9
|
|
349
|
+
- Hunger: {hunger}/9
|
|
350
|
+
- Inventory: {inventory_str}
|
|
351
|
+
- Nearby objects: {nearby_str}
|
|
352
|
+
- Achievements unlocked: {achievements_str}
|
|
353
|
+
{map_str}
|
|
354
|
+
|
|
355
|
+
Choose your next actions based on what you see. Use the 'interact' tool with a list of action IDs.
|
|
356
|
+
|
|
357
|
+
Tips:
|
|
358
|
+
- Look at the map! T=tree (wood), S=stone, C=cow (food), W=water
|
|
359
|
+
- To collect resources: move to them (actions 1-4) then use action 5 (do)
|
|
360
|
+
- To craft: place table (8) first, then craft tools (11-16)
|
|
361
|
+
- If hungry and see cow (C), move to it and eat (17)
|
|
362
|
+
|
|
363
|
+
What actions do you want to take?"""
|
|
364
|
+
|
|
365
|
+
# Send observation as message
|
|
366
|
+
obs_msg = create_message(
|
|
367
|
+
f"Observation: {compress_observation_for_trace(obs)}",
|
|
368
|
+
"system",
|
|
369
|
+
f"crafter_env_{instance_id}",
|
|
370
|
+
turn
|
|
371
|
+
)
|
|
372
|
+
await session_tracer.record_message(
|
|
373
|
+
content=obs_msg.content,
|
|
374
|
+
message_type=obs_msg.message_type,
|
|
375
|
+
event_time=obs_msg.time_record.event_time,
|
|
376
|
+
message_time=obs_msg.time_record.message_time,
|
|
377
|
+
metadata=obs_msg.metadata
|
|
378
|
+
)
|
|
379
|
+
|
|
380
|
+
# Get action from LM with tools (with timeout)
|
|
381
|
+
turn_start_time = time.time()
|
|
382
|
+
try:
|
|
383
|
+
# Define the interact tool for Crafter
|
|
384
|
+
from pydantic import BaseModel, Field
|
|
385
|
+
from synth_ai.lm.tools.base import BaseTool
|
|
386
|
+
|
|
387
|
+
class InteractArgs(BaseModel):
|
|
388
|
+
actions: list[int] = Field(..., description="List of action IDs to execute")
|
|
389
|
+
|
|
390
|
+
interact_tool = BaseTool(
|
|
391
|
+
name="interact",
|
|
392
|
+
arguments=InteractArgs,
|
|
393
|
+
description="Execute actions in the Crafter game"
|
|
394
|
+
)
|
|
395
|
+
|
|
396
|
+
# Create system message that explains available actions
|
|
397
|
+
action_list = "\n".join([f"{action_id}: {action}" for action, action_id in CRAFTER_ACTIONS.items()])
|
|
398
|
+
system_message = f"""You are an agent playing Crafter, a 2D survival game. Your goal is to survive and unlock achievements.
|
|
399
|
+
|
|
400
|
+
You MUST use the 'interact' tool to execute actions. The tool takes a list of action IDs.
|
|
401
|
+
|
|
402
|
+
Action ID mapping:
|
|
403
|
+
{action_list}
|
|
404
|
+
|
|
405
|
+
Strategy tips:
|
|
406
|
+
- Start by collecting wood (move to trees and use action 5)
|
|
407
|
+
- Place a crafting table (action 8) to unlock crafting recipes
|
|
408
|
+
- Craft tools to collect resources more efficiently
|
|
409
|
+
- Eat when hungry, sleep when tired
|
|
410
|
+
- Explore to find different resources
|
|
411
|
+
|
|
412
|
+
IMPORTANT: Always use the 'interact' tool with a list of action IDs. For example: interact(actions=[2, 2, 5]) to move right twice and collect."""
|
|
413
|
+
|
|
414
|
+
# Get actions from LM using tools with timeout
|
|
415
|
+
try:
|
|
416
|
+
action_response = await asyncio.wait_for(
|
|
417
|
+
lm.respond_async(
|
|
418
|
+
system_message=system_message,
|
|
419
|
+
user_message=prompt,
|
|
420
|
+
tools=[interact_tool],
|
|
421
|
+
turn_number=turn
|
|
422
|
+
),
|
|
423
|
+
timeout=config.turn_timeout
|
|
424
|
+
)
|
|
425
|
+
except asyncio.TimeoutError:
|
|
426
|
+
print(f" ā° Turn {turn} timed out for episode {episode_num} after {config.turn_timeout}s")
|
|
427
|
+
action_response = None
|
|
428
|
+
done = True
|
|
429
|
+
break
|
|
430
|
+
|
|
431
|
+
# Debug: print response (removed for clean output)
|
|
432
|
+
|
|
433
|
+
# Extract tool calls from response
|
|
434
|
+
if hasattr(action_response, 'tool_calls') and action_response.tool_calls:
|
|
435
|
+
tool_calls = action_response.tool_calls
|
|
436
|
+
|
|
437
|
+
# Process each tool call
|
|
438
|
+
for tool_call in tool_calls:
|
|
439
|
+
if tool_call.get('function', {}).get('name') == 'interact':
|
|
440
|
+
# Extract actions from the tool call
|
|
441
|
+
import json
|
|
442
|
+
args = json.loads(tool_call.get('function', {}).get('arguments', '{}'))
|
|
443
|
+
actions = args.get('actions', [])
|
|
444
|
+
|
|
445
|
+
if not actions:
|
|
446
|
+
# If no actions provided, use noop
|
|
447
|
+
actions = [0]
|
|
448
|
+
|
|
449
|
+
# Execute each action separately
|
|
450
|
+
for action_id in actions:
|
|
451
|
+
total_actions += 1
|
|
452
|
+
|
|
453
|
+
# Validate action ID
|
|
454
|
+
if action_id not in INT_TO_ACTION_STRING:
|
|
455
|
+
# Invalid action logging removed for clean output
|
|
456
|
+
action_id = 0
|
|
457
|
+
invalid_actions += 1
|
|
458
|
+
|
|
459
|
+
# Send action to Crafter service with timeout
|
|
460
|
+
try:
|
|
461
|
+
step_response = await asyncio.wait_for(
|
|
462
|
+
retry_http_request(
|
|
463
|
+
client, "POST", f"{config.crafter_service_url}/env/CrafterClassic/step",
|
|
464
|
+
json={
|
|
465
|
+
"env_id": instance_id,
|
|
466
|
+
"action": {
|
|
467
|
+
"tool_calls": [
|
|
468
|
+
{"tool": "interact", "args": {"action": action_id}}
|
|
469
|
+
]
|
|
470
|
+
}
|
|
471
|
+
}
|
|
472
|
+
),
|
|
473
|
+
timeout=5.0 # 5 second timeout for individual action
|
|
474
|
+
)
|
|
475
|
+
except asyncio.TimeoutError:
|
|
476
|
+
print(f" ā° Action execution timed out in episode {episode_num}")
|
|
477
|
+
done = True
|
|
478
|
+
break
|
|
479
|
+
|
|
480
|
+
if step_response.status_code != 200:
|
|
481
|
+
print(f" ā Step failed: {step_response.status_code} - {step_response.text}")
|
|
482
|
+
done = True
|
|
483
|
+
break
|
|
484
|
+
|
|
485
|
+
step_data = step_response.json()
|
|
486
|
+
|
|
487
|
+
# Extract data from response
|
|
488
|
+
new_obs = step_data["observation"]
|
|
489
|
+
reward = step_data["reward"]
|
|
490
|
+
done = step_data["done"]
|
|
491
|
+
|
|
492
|
+
# Record runtime event for action
|
|
493
|
+
action_name = INT_TO_ACTION_STRING.get(action_id, "unknown")
|
|
494
|
+
runtime_event = RuntimeEvent(
|
|
495
|
+
system_instance_id=f"crafter_env_{instance_id}",
|
|
496
|
+
time_record=TimeRecord(
|
|
497
|
+
event_time=time.time(),
|
|
498
|
+
message_time=turn
|
|
499
|
+
),
|
|
500
|
+
actions=[action_id],
|
|
501
|
+
metadata={
|
|
502
|
+
"action_name": action_name,
|
|
503
|
+
"valid": action_name != "noop" or invalid_actions == 0
|
|
504
|
+
}
|
|
505
|
+
)
|
|
506
|
+
await session_tracer.record_event(runtime_event)
|
|
507
|
+
|
|
508
|
+
# Record environment event
|
|
509
|
+
env_event = EnvironmentEvent(
|
|
510
|
+
system_instance_id=f"crafter_env_{instance_id}",
|
|
511
|
+
time_record=TimeRecord(
|
|
512
|
+
event_time=time.time(),
|
|
513
|
+
message_time=turn
|
|
514
|
+
),
|
|
515
|
+
reward=reward,
|
|
516
|
+
terminated=done,
|
|
517
|
+
system_state_before={"observation": prev_obs},
|
|
518
|
+
system_state_after={"observation": new_obs, "public_state": {"achievements_status": new_obs.get("achievements_status", {})}}
|
|
519
|
+
)
|
|
520
|
+
await session_tracer.record_event(env_event)
|
|
521
|
+
|
|
522
|
+
# Update for next turn
|
|
523
|
+
prev_obs = obs
|
|
524
|
+
obs = new_obs
|
|
525
|
+
|
|
526
|
+
if done:
|
|
527
|
+
break
|
|
528
|
+
|
|
529
|
+
# Update progress bar after each action
|
|
530
|
+
if hasattr(config, '_pbar'):
|
|
531
|
+
config._pbar.update(1)
|
|
532
|
+
else:
|
|
533
|
+
# No tool calls provided, use noop
|
|
534
|
+
action_id = 0
|
|
535
|
+
total_actions += 1
|
|
536
|
+
invalid_actions += 1
|
|
537
|
+
|
|
538
|
+
# Send noop action with timeout
|
|
539
|
+
try:
|
|
540
|
+
step_response = await asyncio.wait_for(
|
|
541
|
+
retry_http_request(
|
|
542
|
+
client, "POST", f"{config.crafter_service_url}/env/CrafterClassic/step",
|
|
543
|
+
json={
|
|
544
|
+
"env_id": instance_id,
|
|
545
|
+
"action": {
|
|
546
|
+
"tool_calls": [
|
|
547
|
+
{"tool": "interact", "args": {"action": action_id}}
|
|
548
|
+
]
|
|
549
|
+
}
|
|
550
|
+
}
|
|
551
|
+
),
|
|
552
|
+
timeout=5.0 # 5 second timeout
|
|
553
|
+
)
|
|
554
|
+
except asyncio.TimeoutError:
|
|
555
|
+
print(f" ā° Noop action timed out in episode {episode_num}")
|
|
556
|
+
done = True
|
|
557
|
+
break
|
|
558
|
+
|
|
559
|
+
if step_response.status_code != 200:
|
|
560
|
+
print(f" ā Step failed: {step_response.status_code} - {step_response.text}")
|
|
561
|
+
done = True
|
|
562
|
+
else:
|
|
563
|
+
step_data = step_response.json()
|
|
564
|
+
new_obs = step_data["observation"]
|
|
565
|
+
reward = step_data["reward"]
|
|
566
|
+
done = step_data["done"]
|
|
567
|
+
|
|
568
|
+
# Update observation
|
|
569
|
+
prev_obs = obs
|
|
570
|
+
obs = new_obs
|
|
571
|
+
|
|
572
|
+
# End timestep
|
|
573
|
+
await session_tracer.end_timestep(f"turn_{turn}")
|
|
574
|
+
|
|
575
|
+
except Exception as e:
|
|
576
|
+
print(f" ā Environment step error: {e}")
|
|
577
|
+
done = True
|
|
578
|
+
|
|
579
|
+
# Update progress bar for remaining steps if episode ended early
|
|
580
|
+
if hasattr(config, '_pbar') and turn < config.max_turns - 1:
|
|
581
|
+
remaining_steps = config.max_turns - turn - 1
|
|
582
|
+
config._pbar.update(remaining_steps)
|
|
583
|
+
|
|
584
|
+
# Calculate invalid action rate
|
|
585
|
+
invalid_rate = invalid_actions / total_actions if total_actions > 0 else 0
|
|
586
|
+
|
|
587
|
+
# Calculate achievements
|
|
588
|
+
final_achievements = obs.get("achievements_status", {})
|
|
589
|
+
total_achievements = sum(1 for v in final_achievements.values() if v)
|
|
590
|
+
|
|
591
|
+
# Terminate environment
|
|
592
|
+
try:
|
|
593
|
+
await retry_http_request(
|
|
594
|
+
client, "POST", f"{config.crafter_service_url}/env/CrafterClassic/terminate",
|
|
595
|
+
json={"env_id": instance_id}
|
|
596
|
+
)
|
|
597
|
+
except Exception as e:
|
|
598
|
+
print(f" ā ļø Failed to terminate environment: {e}")
|
|
599
|
+
|
|
600
|
+
# End session
|
|
601
|
+
await session_tracer.end_session(save=config.save_traces)
|
|
602
|
+
# Close the tracer for this episode
|
|
603
|
+
await session_tracer.close()
|
|
604
|
+
|
|
605
|
+
return {
|
|
606
|
+
"model": model_name,
|
|
607
|
+
"episode": episode_num,
|
|
608
|
+
"total_achievements": total_achievements,
|
|
609
|
+
"achievements": final_achievements,
|
|
610
|
+
"invalid_action_rate": invalid_rate,
|
|
611
|
+
"total_actions": total_actions,
|
|
612
|
+
"invalid_actions": invalid_actions,
|
|
613
|
+
"session_id": session_id
|
|
614
|
+
}
|
|
615
|
+
|
|
616
|
+
except Exception as e:
|
|
617
|
+
print(f" ā Episode failed: {e}")
|
|
618
|
+
import traceback
|
|
619
|
+
traceback.print_exc()
|
|
620
|
+
|
|
621
|
+
# End session even if failed
|
|
622
|
+
await session_tracer.end_session(save=config.save_traces)
|
|
623
|
+
# Close the tracer for this episode
|
|
624
|
+
await session_tracer.close()
|
|
625
|
+
|
|
626
|
+
return {
|
|
627
|
+
"model": model_name,
|
|
628
|
+
"episode": episode_num,
|
|
629
|
+
"total_achievements": 0,
|
|
630
|
+
"achievements": {},
|
|
631
|
+
"invalid_action_rate": 1.0,
|
|
632
|
+
"total_actions": 0,
|
|
633
|
+
"invalid_actions": 0,
|
|
634
|
+
"session_id": session_id,
|
|
635
|
+
"error": str(e)
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
async def run_model_experiment(config: ExperimentConfig, model_name: str, experiment_id: str) -> list[dict[str, Any]]:
    """Launch every episode for one model concurrently and gather the outcomes.

    A single progress bar sized ``num_episodes * max_turns`` is created and
    attached to ``config._pbar`` so the episode coroutines can advance it.
    Once all episodes finish, averages over the episodes that did not report
    an ``"error"`` key are shown as the bar's postfix. The bar is always
    closed, even if gathering raises.

    Args:
        config: Experiment settings; ``num_episodes`` and ``max_turns`` are
            read here and ``_pbar`` is written.
        model_name: Model identifier forwarded to each episode.
        experiment_id: Experiment identifier forwarded to each episode.

    Returns:
        One result dict per episode, in episode order.
    """
    print(f"\nš Running {config.num_episodes} episodes for {model_name} in parallel...\n")

    # One shared bar counts steps across all episodes of this model.
    progress = atqdm(
        total=config.num_episodes * config.max_turns,
        desc=f"{model_name}",
        unit="steps",
        leave=True,
    )
    # Episodes find the bar through the config object.
    config._pbar = progress

    try:
        # Each coroutine builds its own tracer; nothing runs until gather().
        episode_coros = [
            run_episode(config, model_name, episode_idx, experiment_id)
            for episode_idx in range(config.num_episodes)
        ]
        results = await asyncio.gather(*episode_coros)

        # Summarise only the episodes that completed without an error entry.
        completed = [outcome for outcome in results if "error" not in outcome]
        if completed:
            n_completed = len(completed)
            mean_achievements = sum(o["total_achievements"] for o in completed) / n_completed
            mean_invalid_rate = sum(o["invalid_action_rate"] for o in completed) / n_completed
            progress.set_postfix({
                "avg_achievements": f"{mean_achievements:.1f}",
                "avg_invalid_rate": f"{mean_invalid_rate:.1%}",
                "success_rate": f"{n_completed}/{len(results)}",
            })
    finally:
        progress.close()

    return results
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
async def analyze_results(config: ExperimentConfig, all_results: dict[str, list[dict[str, Any]]]):
    """Print a cross-model comparison of episode results and export it as JSON.

    Three reports are printed: per-model summary statistics, a table of how
    often each achievement was unlocked, and model usage rows fetched from
    the trace database (restricted to the models in ``all_results``). The
    raw results and the computed statistics are then written to
    ``crafter_experiment_results_<timestamp>.json`` relative to the current
    working directory.

    Args:
        config: Experiment settings; ``database_url``, ``num_episodes``,
            ``max_turns`` and ``difficulty`` are read here.
        all_results: Mapping of model name to its per-episode result dicts.
            Dicts containing an ``"error"`` key are treated as failed
            episodes and excluded from every statistic.
    """
    print("\nš Analysis Results:")
    print("=" * 80)

    # Initialize database manager
    db_manager = AsyncSQLTraceManager(config.database_url)
    await db_manager.initialize()

    try:
        # Filter out failed episodes once per model; every report below reuses this.
        valid_by_model = {
            model: [r for r in results if "error" not in r]
            for model, results in all_results.items()
        }

        # Basic statistics by model (models with no successful episode get no entry).
        model_stats = {}
        for model, results in all_results.items():
            valid_results = valid_by_model[model]
            if not valid_results:
                continue
            achievements = [r["total_achievements"] for r in valid_results]
            invalid_rates = [r["invalid_action_rate"] for r in valid_results]
            model_stats[model] = {
                "avg_achievements": np.mean(achievements),
                "std_achievements": np.std(achievements),
                "max_achievements": max(achievements),
                "avg_invalid_rate": np.mean(invalid_rates),
                "success_rate": len(valid_results) / len(results),
            }

        # Print model comparison, best average first.
        print("\nš Model Performance Summary:")
        print(f"{'Model':<20} {'Avg Achievements':<18} {'Max Achievements':<18} {'Invalid Rate':<15} {'Success Rate':<15}")
        print("-" * 86)

        for model, stats in sorted(model_stats.items(), key=lambda x: x[1]["avg_achievements"], reverse=True):
            print(f"{model:<20} {stats['avg_achievements']:>6.2f} ± {stats['std_achievements']:>4.2f} "
                  f"{stats['max_achievements']:>16} {stats['avg_invalid_rate']:>12.2%} {stats['success_rate']:>12.2%}")

        # Achievement frequency analysis
        print("\nš Achievement Frequencies:")
        achievement_counts = defaultdict(lambda: defaultdict(int))

        for model, valid_results in valid_by_model.items():
            for result in valid_results:
                for achievement, unlocked in result["achievements"].items():
                    if unlocked:
                        achievement_counts[model][achievement] += 1

        # Get all unique achievements
        all_achievements = set()
        for model_achievements in achievement_counts.values():
            all_achievements.update(model_achievements.keys())

        # Print achievement table
        if all_achievements:
            model_order = sorted(all_results)
            # Loop-invariant: number of successful episodes per model.
            valid_totals = {model: len(valid_by_model[model]) for model in model_order}
            # Each data cell is " " + "{count:>3}/{total:<3} ({pct:>3.0f}%)",
            # i.e. a space followed by 14 characters. The header and the rule
            # use the same width so the columns line up, and the wider header
            # keeps model names that share a prefix distinguishable.
            cell_width = 14

            print(f"\n{'Achievement':<25} " + " ".join(f"{model[:cell_width]:>{cell_width}}" for model in model_order))
            print("-" * (25 + (cell_width + 1) * len(model_order)))

            for achievement in sorted(all_achievements):
                row = f"{achievement:<25}"
                for model in model_order:
                    count = achievement_counts[model].get(achievement, 0)
                    total = valid_totals[model]
                    pct = (count / total * 100) if total > 0 else 0
                    row += f" {count:>3}/{total:<3} ({pct:>3.0f}%)"
                print(row)

        # Query model usage from database - filter to only show models used in this experiment
        print("\nš° Model Usage Statistics from Current Experiment:")
        # NOTE(review): get_model_usage() is used as a pandas DataFrame below
        # (.empty, column indexing, .iterrows()) - confirm against the manager.
        model_usage_df = await db_manager.get_model_usage()

        if model_usage_df is not None and not model_usage_df.empty:
            # Filter to only show models from this experiment
            experiment_models = set(all_results.keys())
            filtered_df = model_usage_df[model_usage_df['model_name'].isin(experiment_models)]

            if not filtered_df.empty:
                # Format model usage statistics as table
                print(f"{'Model':<20} {'Provider':<10} {'Usage Count':<12} {'Avg Latency (ms)':<18} {'Total Cost':<12}")
                print("-" * 72)
                for _, row in filtered_df.iterrows():
                    avg_latency = row['avg_latency_ms']
                    # A missing latency is shown as N/A in the same 18-wide column.
                    latency_cell = f"{avg_latency:<18.2f}" if pd.notna(avg_latency) else f"{'N/A':<18}"
                    print(f"{row['model_name']:<20} {row['provider'] or 'N/A':<10} {row['usage_count']:<12} "
                          f"{latency_cell} ${row['total_cost_usd']:<11.4f}")

        # Export detailed results
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = f"crafter_experiment_results_{timestamp}.json"

        # Explicit encoding so the export does not depend on the platform default.
        with open(results_file, "w", encoding="utf-8") as f:
            json.dump({
                "config": {
                    "num_episodes": config.num_episodes,
                    "max_turns": config.max_turns,
                    "difficulty": config.difficulty,
                    "models": list(all_results.keys())
                },
                "results": all_results,
                "statistics": model_stats,
                "timestamp": timestamp
            }, f, indent=2)

        print(f"\nš¾ Detailed results saved to: {results_file}")

    finally:
        # Release the database connection even if reporting or export failed.
        await db_manager.close()
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
async def main(argv=None):
    """Main entry point for the experiment.

    Parses command-line options into an ``ExperimentConfig``, prints the
    effective settings, checks the Crafter service's ``/health`` endpoint,
    runs every requested model one after another, and finally analyzes the
    collected results. Returns early (after printing a message) when the
    health check fails or the service cannot be reached.

    Args:
        argv: Optional list of command-line tokens. ``None`` (the default)
            lets ``argparse`` read ``sys.argv[1:]``, so calling ``main()``
            with no arguments behaves as before.
    """
    parser = argparse.ArgumentParser(description="Run Crafter experiments with multiple models")
    parser.add_argument("--episodes", type=int, default=5, help="Number of episodes per model")
    parser.add_argument("--max-turns", type=int, default=100, help="Maximum turns per episode")
    parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy", help="Game difficulty")
    parser.add_argument("--models", nargs="+", default=MODELS_TO_TEST, help="Models to test")
    parser.add_argument("--no-save", action="store_true", help="Don't save traces to database")
    parser.add_argument("--quiet", action="store_true", help="Reduce output verbosity")
    parser.add_argument("--db-url", default=DATABASE_URL, help="Database URL for tracing")
    parser.add_argument("--base-seed", type=int, default=1000, help="Base seed for episodes (episodes use base_seed+episode_num)")
    parser.add_argument("--turn-timeout", type=float, default=30.0, help="Timeout per turn in seconds")
    parser.add_argument("--episode-timeout", type=float, default=300.0, help="Total timeout per episode in seconds")

    args = parser.parse_args(argv)

    # Create configuration
    config = ExperimentConfig()
    config.num_episodes = args.episodes
    config.max_turns = args.max_turns
    config.difficulty = args.difficulty
    config.save_traces = not args.no_save
    config.verbose = not args.quiet
    config.quiet = args.quiet
    config.database_url = args.db_url
    config.base_seed = args.base_seed
    config.turn_timeout = args.turn_timeout
    config.episode_timeout = args.episode_timeout

    # Generate experiment ID
    experiment_id = f"crafter_multi_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    print("š® Crafter Multi-Model Experiment")
    print("=" * 50)
    print(f"Experiment ID: {experiment_id}")
    print(f"Models: {', '.join(args.models)}")
    print(f"Episodes per model: {config.num_episodes}")
    print(f"Max turns per episode: {config.max_turns}")
    print(f"Difficulty: {config.difficulty}")
    print(f"Seeds: {config.base_seed} to {config.base_seed + config.num_episodes - 1}")
    print(f"Turn timeout: {config.turn_timeout}s")
    print(f"Episode timeout: {config.episode_timeout}s")
    print(f"Save traces: {config.save_traces}")
    print(f"Database URL: {config.database_url}")
    print("=" * 50)

    # Check Crafter service before spending any model calls.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{config.crafter_service_url}/health", timeout=5.0)
            if response.status_code != 200:
                print(f"ā Crafter service not healthy at {config.crafter_service_url}")
                return
    except Exception as e:
        print(f"ā Cannot connect to Crafter service at {config.crafter_service_url}: {e}")
        print("Please ensure the Crafter service is running.")
        return

    print("✅ Crafter service is running")

    # Run experiments for each model; models run sequentially, one after another.
    all_results = {}

    for model in args.models:
        results = await run_model_experiment(config, model, experiment_id)
        all_results[model] = results

    # Analyze and compare results
    await analyze_results(config, all_results)

    print("\n✅ Experiment complete!")
|
|
855
|
+
|
|
856
|
+
|
|
857
|
+
# Script entry point: run the async ``main`` coroutine to completion when the
# file is executed directly (not when it is imported).
if __name__ == "__main__":
    asyncio.run(main())
|