synth-ai 0.2.2.dev0__py3-none-any.whl → 0.2.4.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/cli/__init__.py +66 -0
- synth_ai/cli/balance.py +205 -0
- synth_ai/cli/calc.py +70 -0
- synth_ai/cli/demo.py +74 -0
- synth_ai/{cli.py → cli/legacy_root_backup.py} +60 -15
- synth_ai/cli/man.py +103 -0
- synth_ai/cli/recent.py +126 -0
- synth_ai/cli/root.py +184 -0
- synth_ai/cli/status.py +126 -0
- synth_ai/cli/traces.py +136 -0
- synth_ai/cli/watch.py +508 -0
- synth_ai/config/base_url.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +252 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_duckdb_v2_backup.py +413 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +760 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_synth.py +34 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth.py +1740 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth_v2_backup.py +1318 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_duckdb_v2_backup.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v2_backup.py +1352 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +4 -4
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/test_crafter_react_agent_openai_v2_backup.py +2551 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +1 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/old/traces/session_crafter_episode_16_15227b68-2906-416f-acc4-d6a9b4fa5828_20250725_001154.json +1363 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +3 -3
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/crafter_custom/environment.py +1 -1
- synth_ai/environments/examples/enron/dataset/corbt___enron_emails_sample_questions/default/0.0.0/293c9fe8170037e01cc9cf5834e0cd5ef6f1a6bb/dataset_info.json +1 -0
- synth_ai/environments/examples/nethack/helpers/achievements.json +64 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +1 -1
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +5 -5
- synth_ai/environments/examples/red/units/test_movement_debug.py +2 -2
- synth_ai/environments/examples/red/units/test_retry_movement.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/available_envs.json +122 -0
- synth_ai/environments/examples/sokoban/verified_puzzles.json +54987 -0
- synth_ai/environments/service/core_routes.py +1 -1
- synth_ai/experimental/synth_oss.py +446 -0
- synth_ai/learning/core.py +21 -0
- synth_ai/learning/gateway.py +4 -0
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/mipro.py +8 -0
- synth_ai/lm/__init__.py +3 -0
- synth_ai/lm/core/main.py +4 -0
- synth_ai/lm/core/main_v3.py +238 -122
- synth_ai/lm/core/vendor_clients.py +4 -0
- synth_ai/lm/provider_support/openai.py +11 -2
- synth_ai/lm/vendors/base.py +7 -0
- synth_ai/lm/vendors/openai_standard.py +339 -4
- synth_ai/lm/vendors/openai_standard_responses.py +243 -0
- synth_ai/lm/vendors/synth_client.py +155 -5
- synth_ai/lm/warmup.py +54 -17
- synth_ai/tracing/__init__.py +18 -0
- synth_ai/tracing_v1/__init__.py +29 -14
- synth_ai/tracing_v3/__init__.py +2 -2
- synth_ai/tracing_v3/abstractions.py +62 -17
- synth_ai/tracing_v3/config.py +13 -7
- synth_ai/tracing_v3/db_config.py +6 -6
- synth_ai/tracing_v3/hooks.py +1 -1
- synth_ai/tracing_v3/llm_call_record_helpers.py +350 -0
- synth_ai/tracing_v3/lm_call_record_abstractions.py +257 -0
- synth_ai/tracing_v3/session_tracer.py +5 -5
- synth_ai/tracing_v3/tests/test_concurrent_operations.py +1 -1
- synth_ai/tracing_v3/tests/test_llm_call_records.py +672 -0
- synth_ai/tracing_v3/tests/test_session_tracer.py +43 -9
- synth_ai/tracing_v3/tests/test_turso_manager.py +1 -1
- synth_ai/tracing_v3/turso/manager.py +18 -11
- synth_ai/tracing_v3/turso/models.py +1 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/dashboard.py +329 -0
- synth_ai/v0/tracing/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/base_client.py +3 -3
- synth_ai/{tracing → v0/tracing}/client_manager.py +1 -1
- synth_ai/{tracing → v0/tracing}/context.py +1 -1
- synth_ai/{tracing → v0/tracing}/decorators.py +11 -11
- synth_ai/v0/tracing/events/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/events/manage.py +4 -4
- synth_ai/{tracing → v0/tracing}/events/scope.py +6 -6
- synth_ai/{tracing → v0/tracing}/events/store.py +3 -3
- synth_ai/{tracing → v0/tracing}/immediate_client.py +6 -6
- synth_ai/{tracing → v0/tracing}/log_client_base.py +2 -2
- synth_ai/{tracing → v0/tracing}/retry_queue.py +3 -3
- synth_ai/{tracing → v0/tracing}/trackers.py +2 -2
- synth_ai/{tracing → v0/tracing}/upload.py +4 -4
- synth_ai/v0/tracing_v1/__init__.py +16 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/base_client.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/client_manager.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/context.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/decorators.py +11 -11
- synth_ai/v0/tracing_v1/events/__init__.py +0 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/manage.py +4 -4
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/scope.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/store.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/immediate_client.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/log_client_base.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/retry_queue.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/trackers.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/upload.py +4 -4
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/METADATA +100 -5
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/RECORD +115 -75
- /synth_ai/{tracing/events/__init__.py → compound/cais.py} +0 -0
- /synth_ai/{tracing_v1/events/__init__.py → environments/examples/crafter_classic/debug_translation.py} +0 -0
- /synth_ai/{tracing → v0/tracing}/abstractions.py +0 -0
- /synth_ai/{tracing → v0/tracing}/config.py +0 -0
- /synth_ai/{tracing → v0/tracing}/local.py +0 -0
- /synth_ai/{tracing → v0/tracing}/utils.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/abstractions.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/config.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/local.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/utils.py +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/top_level.txt +0 -0
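Most of the moves above relocate the old `tracing` and `tracing_v1` packages under a new `synth_ai.v0` namespace. A minimal sketch of what that can mean for downstream imports, assuming no compatibility re-export is in place (the updated `synth_ai/tracing/__init__.py`, +18 lines in this release, may provide one):

```python
# Hypothetical illustration only: the module paths come from the file list above,
# but whether the old paths keep working depends on the shims shipped in 0.2.4.dev2.
try:
    from synth_ai.v0.tracing import decorators as tracing_decorators  # new layout
except ImportError:
    from synth_ai.tracing import decorators as tracing_decorators  # pre-0.2.4 layout
```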
synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v2_backup.py (new file)
@@ -0,0 +1,1352 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Comprehensive script to run Crafter rollouts for multiple models and compare their performance.
|
4
|
+
|
5
|
+
Runs experiments for:
|
6
|
+
- gpt-4o-mini
|
7
|
+
- gpt-4.1-mini
|
8
|
+
- gpt-4.1-nano
|
9
|
+
- gemini-1.5-flash
|
10
|
+
- gemini-2.5-flash-lite
|
11
|
+
- qwen3/32b
|
12
|
+
|
13
|
+
Analyzes and compares:
|
14
|
+
- Invalid action rates
|
15
|
+
- Achievement frequencies by step
|
16
|
+
- Achievement counts across models
|
17
|
+
- Performance metrics
|
18
|
+
- Cost analysis
|
19
|
+
"""
|
20
|
+
|
21
|
+
import asyncio
|
22
|
+
import json
|
23
|
+
import uuid
|
24
|
+
import argparse
|
25
|
+
import logging
|
26
|
+
import time
|
27
|
+
import toml
|
28
|
+
from datetime import datetime
|
29
|
+
from typing import Dict, Any, Optional, List, Set, Tuple
|
30
|
+
from pathlib import Path
|
31
|
+
import sys
|
32
|
+
import os
|
33
|
+
from collections import defaultdict, Counter
|
34
|
+
import pandas as pd
|
35
|
+
import numpy as np
|
36
|
+
from tqdm.asyncio import tqdm_asyncio
|
37
|
+
from tqdm import tqdm
|
38
|
+
import duckdb
|
39
|
+
|
40
|
+
# Add parent directory to path for imports
|
41
|
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent.parent))
|
42
|
+
|
43
|
+
# Disable v1 logging to see v2 tracing clearly
|
44
|
+
os.environ["LANGFUSE_ENABLED"] = "false"
|
45
|
+
os.environ["SYNTH_LOGGING"] = "false"
|
46
|
+
|
47
|
+
# Import enhanced LM with v2 tracing
|
48
|
+
from synth_ai.lm.core.main_v2 import LM
|
49
|
+
|
50
|
+
# Import session tracer for v2 tracing
|
51
|
+
from synth_ai.tracing_v2.session_tracer import (
|
52
|
+
SessionTracer, SessionEventMessage, TimeRecord,
|
53
|
+
RuntimeEvent, EnvironmentEvent, LMCAISEvent
|
54
|
+
)
|
55
|
+
from synth_ai.tracing_v2.utils import create_experiment_context
|
56
|
+
from synth_ai.tracing_v2.duckdb.manager import DuckDBTraceManager
|
57
|
+
from synth_ai.tracing_v2.decorators import (
|
58
|
+
set_active_session_tracer, set_system_id, set_turn_number
|
59
|
+
)
|
60
|
+
|
61
|
+
# Import Crafter hooks
|
62
|
+
try:
|
63
|
+
from synth_ai.environments.examples.crafter_classic.trace_hooks import CRAFTER_HOOKS
|
64
|
+
print(f"✅ Loaded {len(CRAFTER_HOOKS)} Crafter achievement hooks (Easy, Medium, Hard)")
|
65
|
+
except ImportError:
|
66
|
+
print("Warning: Could not import CRAFTER_HOOKS")
|
67
|
+
CRAFTER_HOOKS = []
|
68
|
+
|
69
|
+
import httpx
|
70
|
+
import random
|
71
|
+
|
72
|
+
# Global buckets for sessions
|
73
|
+
_SESSIONS: dict[str, tuple[str, object]] = {} # session_id -> (experiment_id, trace)
|
74
|
+
|
75
|
+
# Configuration
|
76
|
+
MODELS_TO_TEST = [
|
77
|
+
"gpt-4o-mini",
|
78
|
+
#"gpt-4.1-mini",
|
79
|
+
"gpt-4.1-nano",
|
80
|
+
# "gemini-1.5-flash",
|
81
|
+
# "gemini-2.5-flash-lite",
|
82
|
+
]
|
83
|
+
|
84
|
+
# Service URLs (modify these based on your setup)
|
85
|
+
CRAFTER_SERVICE_URL = "http://localhost:8901"
|
86
|
+
DATABASE_PATH = "/Users/joshuapurtell/Documents/GitHub/synth-ai/synth_ai/traces/crafter_multi_model_traces.duckdb"
|
87
|
+
|
88
|
+
# Retry configuration for HTTP requests
|
89
|
+
MAX_RETRIES = 3
|
90
|
+
BASE_DELAY = 0.1
|
91
|
+
MAX_DELAY = 2.0
|
92
|
+
HTTP_TIMEOUT = 30.0
|
93
|
+
|
94
|
+
class ExperimentConfig:
|
95
|
+
"""Configuration for the multi-model experiment."""
|
96
|
+
|
97
|
+
def __init__(self):
|
98
|
+
self.num_episodes = 10 # Number of episodes per model
|
99
|
+
self.max_turns = 100 # Max turns per episode
|
100
|
+
self.difficulty = "easy"
|
101
|
+
self.save_traces = True
|
102
|
+
self.verbose = True
|
103
|
+
self.quiet = False # Default to verbose mode
|
104
|
+
self.enable_v2_tracing = True
|
105
|
+
self.v2_trace_dir = "./traces"
|
106
|
+
self.crafter_service_url = CRAFTER_SERVICE_URL
|
107
|
+
self.database_path = DATABASE_PATH
|
108
|
+
|
109
|
+
|
110
|
+
async def retry_http_request(client: httpx.AsyncClient, method: str, url: str, **kwargs) -> Any:
|
111
|
+
"""Retry HTTP requests with exponential backoff and jitter."""
|
112
|
+
last_exception = None
|
113
|
+
|
114
|
+
for attempt in range(MAX_RETRIES):
|
115
|
+
try:
|
116
|
+
if attempt > 0:
|
117
|
+
delay = min(BASE_DELAY * (2 ** (attempt - 1)), MAX_DELAY)
|
118
|
+
jitter = random.uniform(0, 0.1 * delay)
|
119
|
+
total_delay = delay + jitter
|
120
|
+
await asyncio.sleep(total_delay)
|
121
|
+
|
122
|
+
response = await client.request(method, url, timeout=HTTP_TIMEOUT, **kwargs)
|
123
|
+
|
124
|
+
if response.status_code < 500:
|
125
|
+
return response
|
126
|
+
|
127
|
+
last_exception = Exception(f"HTTP {response.status_code}: {response.text}")
|
128
|
+
|
129
|
+
except httpx.ConnectError as e:
|
130
|
+
last_exception = Exception(f"Connection failed to {url}: {e}")
|
131
|
+
if attempt < MAX_RETRIES - 1:
|
132
|
+
await asyncio.sleep(1.0 * (2 ** attempt))
|
133
|
+
except httpx.ReadError as e:
|
134
|
+
last_exception = e
|
135
|
+
if attempt < MAX_RETRIES - 1:
|
136
|
+
read_error_delay = min(1.0 * (2 ** attempt), 5.0)
|
137
|
+
await asyncio.sleep(read_error_delay)
|
138
|
+
except Exception as e:
|
139
|
+
last_exception = e
|
140
|
+
|
141
|
+
print(f" ❌ HTTP request failed after {MAX_RETRIES} attempts: {method} {url}")
|
142
|
+
print(f" ❌ Error: {type(last_exception).__name__}: {str(last_exception)[:200]}")
|
143
|
+
raise last_exception
|
144
|
+
|
145
|
+
|
146
|
+
# Crafter action mapping
|
147
|
+
CRAFTER_ACTIONS = {
|
148
|
+
"noop": 0, "move_left": 1, "move_right": 2, "move_up": 3, "move_down": 4,
|
149
|
+
"do": 5, "sleep": 6, "place_stone": 7, "place_table": 8, "place_furnace": 9,
|
150
|
+
"place_plant": 10, "make_wood_pickaxe": 11, "make_stone_pickaxe": 12,
|
151
|
+
"make_iron_pickaxe": 13, "make_wood_sword": 14, "make_stone_sword": 15,
|
152
|
+
"make_iron_sword": 16,
|
153
|
+
# Aliases
|
154
|
+
"move": 5, # "move" -> "do" (context-dependent action)
|
155
|
+
"collect": 5, # "collect" -> "do"
|
156
|
+
"attack": 5, # "attack" -> "do"
|
157
|
+
"eat": 5, # "eat" -> "do"
|
158
|
+
}
|
159
|
+
|
160
|
+
def action_to_int(action: str) -> int:
|
161
|
+
"""Convert action string to integer."""
|
162
|
+
normalized = action.strip().lower().replace(" ", "_")
|
163
|
+
return CRAFTER_ACTIONS.get(normalized, 5) # Default to "do"
|
164
|
+
|
165
|
+
|
166
|
+
def create_message(content: Any, message_type: str, origin_system_id: Any = None, turn: int = None) -> SessionEventMessage:
|
167
|
+
"""Create a session event message."""
|
168
|
+
return SessionEventMessage(
|
169
|
+
content=str(content),
|
170
|
+
message_type=message_type,
|
171
|
+
time_record=TimeRecord(
|
172
|
+
message_time=turn if turn is not None else 0,
|
173
|
+
event_time=time.time()
|
174
|
+
)
|
175
|
+
)
|
176
|
+
|
177
|
+
|
178
|
+
def compress_observation_for_trace(obs: Dict[str, Any]) -> Dict[str, Any]:
|
179
|
+
"""Compress observation for tracing."""
|
180
|
+
return {
|
181
|
+
"inventory": obs.get("inventory", {}),
|
182
|
+
"nearby": obs.get("nearby", []),
|
183
|
+
"status": obs.get("status", {}),
|
184
|
+
"achievement": obs.get("achievement", None)
|
185
|
+
}
|
186
|
+
|
187
|
+
|
188
|
+
async def run_episode(episode_id: int, model_name: str, config: ExperimentConfig,
|
189
|
+
session_tracer: SessionTracer) -> Dict[str, Any]:
|
190
|
+
"""Run a single episode with the specified model."""
|
191
|
+
episode_start_time = time.time()
|
192
|
+
episode_reward = 0.0
|
193
|
+
step_results = []
|
194
|
+
termination_reason = "max_steps"
|
195
|
+
|
196
|
+
# Set up LM for this model
|
197
|
+
lm = LM(
|
198
|
+
model_name=model_name,
|
199
|
+
formatting_model_name="gpt-4o-mini", # Use a reliable model for formatting
|
200
|
+
temperature=0.1, # Low temperature for more consistent gameplay
|
201
|
+
session_tracer=session_tracer,
|
202
|
+
system_id=f"crafter_agent_{model_name}",
|
203
|
+
enable_v2_tracing=True
|
204
|
+
)
|
205
|
+
|
206
|
+
# Create HTTP client
|
207
|
+
async with httpx.AsyncClient() as client:
|
208
|
+
try:
|
209
|
+
# Initialize environment
|
210
|
+
init_response = await retry_http_request(
|
211
|
+
client, "POST", f"{config.crafter_service_url}/env/CrafterClassic/initialize",
|
212
|
+
json={"difficulty": config.difficulty, "seed": random.randint(0, 1000000)}
|
213
|
+
)
|
214
|
+
init_data = init_response.json()
|
215
|
+
|
216
|
+
# Debug the response format
|
217
|
+
if config.verbose and not config.quiet:
|
218
|
+
print(f"Init response: {init_data}")
|
219
|
+
|
220
|
+
# Handle different possible response formats
|
221
|
+
if "env_id" in init_data:
|
222
|
+
instance_id = init_data["env_id"]
|
223
|
+
elif "instance_id" in init_data:
|
224
|
+
instance_id = init_data["instance_id"]
|
225
|
+
elif "id" in init_data:
|
226
|
+
instance_id = init_data["id"]
|
227
|
+
else:
|
228
|
+
# If none of the expected keys exist, print the response and raise a clear error
|
229
|
+
print(f"❌ Unexpected response format from Crafter service: {init_data}")
|
230
|
+
raise KeyError(f"Could not find environment ID in response. Available keys: {list(init_data.keys())}")
|
231
|
+
|
232
|
+
# Get initial observation (from initialize response)
|
233
|
+
obs = init_data["observation"]
|
234
|
+
|
235
|
+
prev_obs = obs
|
236
|
+
done = False
|
237
|
+
invalid_actions = 0
|
238
|
+
total_actions = 0
|
239
|
+
|
240
|
+
for turn in range(config.max_turns):
|
241
|
+
if done:
|
242
|
+
break
|
243
|
+
|
244
|
+
set_turn_number(turn)
|
245
|
+
|
246
|
+
# Start timestep for this turn
|
247
|
+
session_tracer.start_timestep(f"turn_{turn}")
|
248
|
+
|
249
|
+
# Prepare context for the agent
|
250
|
+
inventory_str = ", ".join([f"{k}: {v}" for k, v in obs.get("inventory", {}).items() if v > 0])
|
251
|
+
if not inventory_str:
|
252
|
+
inventory_str = "empty"
|
253
|
+
|
254
|
+
nearby_str = ", ".join(obs.get("nearby", []))
|
255
|
+
if not nearby_str:
|
256
|
+
nearby_str = "nothing"
|
257
|
+
|
258
|
+
status = obs.get("status", {})
|
259
|
+
health = status.get("health", 0)
|
260
|
+
hunger = status.get("food", 0)
|
261
|
+
|
262
|
+
# Create agent prompt
|
263
|
+
prompt = f"""You are playing Crafter, a 2D survival game. Choose your next action.
|
264
|
+
|
265
|
+
Current status:
|
266
|
+
- Health: {health}/9
|
267
|
+
- Hunger: {hunger}/9
|
268
|
+
- Inventory: {inventory_str}
|
269
|
+
- Nearby objects: {nearby_str}
|
270
|
+
|
271
|
+
Available actions: do, move_left, move_right, move_up, move_down, place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword, sleep
|
272
|
+
|
273
|
+
Respond with just the action name (e.g., "do" or "move_left" or "make_wood_pickaxe")."""
|
274
|
+
|
275
|
+
# Send observation as message
|
276
|
+
obs_msg = create_message(
|
277
|
+
compress_observation_for_trace(obs),
|
278
|
+
"observation",
|
279
|
+
f"crafter_env_{instance_id}",
|
280
|
+
turn
|
281
|
+
)
|
282
|
+
session_tracer.record_message(obs_msg)
|
283
|
+
|
284
|
+
# Get action from LM
|
285
|
+
try:
|
286
|
+
action_response = await lm.respond_async(
|
287
|
+
system_message="You are playing Crafter, a 2D survival game. Choose your next action.",
|
288
|
+
user_message=prompt,
|
289
|
+
turn_number=turn
|
290
|
+
)
|
291
|
+
action = action_response.raw_response.strip().lower()
|
292
|
+
|
293
|
+
# Clean up action
|
294
|
+
action = action.replace('"', '').replace("'", "").strip()
|
295
|
+
|
296
|
+
# Send action as message
|
297
|
+
action_msg = create_message(
|
298
|
+
action,
|
299
|
+
"action",
|
300
|
+
f"crafter_agent_{model_name}",
|
301
|
+
turn
|
302
|
+
)
|
303
|
+
session_tracer.record_message(action_msg)
|
304
|
+
|
305
|
+
except Exception as e:
|
306
|
+
if config.verbose and not config.quiet:
|
307
|
+
print(f" ❌ LM call failed: {e}")
|
308
|
+
action = "do" # Default action
|
309
|
+
|
310
|
+
total_actions += 1
|
311
|
+
|
312
|
+
# Convert action to integer and format correctly
|
313
|
+
action_int = action_to_int(action)
|
314
|
+
|
315
|
+
# Take action in environment
|
316
|
+
step_response = await retry_http_request(
|
317
|
+
client, "POST", f"{config.crafter_service_url}/env/CrafterClassic/step",
|
318
|
+
json={"env_id": instance_id, "action": {"tool_calls": [{"tool": "interact", "args": {"action": action_int}}]}}
|
319
|
+
)
|
320
|
+
step_data = step_response.json()
|
321
|
+
|
322
|
+
obs = step_data.get("observation", {})
|
323
|
+
reward = step_data.get("reward")
|
324
|
+
# Ensure reward is always a valid number
|
325
|
+
if reward is None or not isinstance(reward, (int, float)):
|
326
|
+
if config.verbose and not config.quiet:
|
327
|
+
print(f" ⚠️ Invalid reward {reward}, using 0.0")
|
328
|
+
reward = 0.0
|
329
|
+
else:
|
330
|
+
reward = float(reward)
|
331
|
+
|
332
|
+
done = step_data.get("done", False)
|
333
|
+
info = step_data.get("info", {})
|
334
|
+
|
335
|
+
# Check if action was invalid
|
336
|
+
if info.get("invalid_action", False):
|
337
|
+
invalid_actions += 1
|
338
|
+
|
339
|
+
episode_reward += reward
|
340
|
+
|
341
|
+
# Record step results
|
342
|
+
step_result = {
|
343
|
+
"step": turn,
|
344
|
+
"action": action,
|
345
|
+
"reward": reward,
|
346
|
+
"invalid": info.get("invalid_action", False),
|
347
|
+
"achievement": obs.get("achievement"),
|
348
|
+
"health": obs.get("status", {}).get("health", 0),
|
349
|
+
"hunger": obs.get("status", {}).get("food", 0)
|
350
|
+
}
|
351
|
+
step_results.append(step_result)
|
352
|
+
|
353
|
+
# Record runtime event
|
354
|
+
runtime_event = RuntimeEvent(
|
355
|
+
system_instance_id=f"crafter_runtime_{model_name}",
|
356
|
+
time_record=TimeRecord(
|
357
|
+
event_time=datetime.now().isoformat(),
|
358
|
+
message_time=turn
|
359
|
+
),
|
360
|
+
actions=[action_int],
|
361
|
+
metadata={
|
362
|
+
"step": turn,
|
363
|
+
"reward": reward,
|
364
|
+
"done": done,
|
365
|
+
"invalid_action": info.get("invalid_action", False),
|
366
|
+
"action_name": action,
|
367
|
+
"action_int": action_int
|
368
|
+
}
|
369
|
+
)
|
370
|
+
session_tracer.record_event(runtime_event)
|
371
|
+
|
372
|
+
if done:
|
373
|
+
termination_reason = "environment_done"
|
374
|
+
break
|
375
|
+
|
376
|
+
# Terminate instance
|
377
|
+
await retry_http_request(
|
378
|
+
client, "POST", f"{config.crafter_service_url}/env/CrafterClassic/terminate",
|
379
|
+
json={"env_id": instance_id}
|
380
|
+
)
|
381
|
+
|
382
|
+
except Exception as e:
|
383
|
+
print(f"❌ Episode {episode_id} failed: {e}")
|
384
|
+
import traceback
|
385
|
+
traceback.print_exc()
|
386
|
+
return {
|
387
|
+
"episode_id": episode_id,
|
388
|
+
"model": model_name,
|
389
|
+
"error": str(e),
|
390
|
+
"duration": time.time() - episode_start_time
|
391
|
+
}
|
392
|
+
|
393
|
+
# Calculate metrics
|
394
|
+
invalid_action_rate = invalid_actions / total_actions if total_actions > 0 else 0.0
|
395
|
+
|
396
|
+
return {
|
397
|
+
"episode_id": episode_id,
|
398
|
+
"model": model_name,
|
399
|
+
"total_reward": episode_reward,
|
400
|
+
"steps": len(step_results),
|
401
|
+
"termination_reason": termination_reason,
|
402
|
+
"duration": time.time() - episode_start_time,
|
403
|
+
"invalid_action_rate": invalid_action_rate,
|
404
|
+
"invalid_actions": invalid_actions,
|
405
|
+
"total_actions": total_actions,
|
406
|
+
"step_results": step_results
|
407
|
+
}
|
408
|
+
|
409
|
+
|
410
|
+
async def run_episode_async(episode_id: int, model_name: str, config: ExperimentConfig,
|
411
|
+
experiment_id: str) -> Dict[str, Any]:
|
412
|
+
"""Run a single episode asynchronously with its own tracer."""
|
413
|
+
# Create unique session ID with timestamp for better uniqueness
|
414
|
+
import time
|
415
|
+
timestamp = int(time.time() * 1000000) # microseconds for uniqueness
|
416
|
+
uuid_val = uuid.uuid4()
|
417
|
+
session_id = f"episode_{episode_id}_{model_name.replace('/', '_')}_{timestamp}_{uuid_val}"
|
418
|
+
|
419
|
+
# Debug session ID generation
|
420
|
+
print(f"🔧 Generated session_id: {session_id}")
|
421
|
+
print(f" Episode: {episode_id}, Model: {model_name}")
|
422
|
+
print(f" Timestamp: {timestamp}, UUID: {uuid_val}")
|
423
|
+
print(f" Model name sanitized: {model_name.replace('/', '_')}")
|
424
|
+
|
425
|
+
# Create individual tracer for this episode (no DB to avoid conflicts)
|
426
|
+
tracer = SessionTracer(hooks=CRAFTER_HOOKS, duckdb_path="")
|
427
|
+
|
428
|
+
# Add small delay to reduce database contention and ensure unique timestamps
|
429
|
+
await asyncio.sleep(0.01 * episode_id) # Staggered start times
|
430
|
+
|
431
|
+
# Additional delay to ensure timestamp uniqueness
|
432
|
+
await asyncio.sleep(0.001) # 1ms additional delay
|
433
|
+
|
434
|
+
tracer.start_session(session_id)
|
435
|
+
|
436
|
+
try:
|
437
|
+
# Run the episode
|
438
|
+
result = await run_episode(episode_id, model_name, config, tracer)
|
439
|
+
|
440
|
+
# Get reference to session before ending it
|
441
|
+
session_to_upload = tracer.current_session
|
442
|
+
|
443
|
+
# End session without uploading to DB (we'll do it at the end to avoid races)
|
444
|
+
trace_path = tracer.end_session(save=True, upload_to_db=False)
|
445
|
+
|
446
|
+
# Store session for batch upload at the end
|
447
|
+
if session_id in _SESSIONS:
|
448
|
+
print(f"⚠️ WARNING: Session {session_id} already in _SESSIONS! Skipping duplicate.")
|
449
|
+
print(f" Existing experiment_id: {_SESSIONS[session_id][0]}")
|
450
|
+
print(f" New experiment_id: {experiment_id}")
|
451
|
+
print(f" Existing trace type: {type(_SESSIONS[session_id][1])}")
|
452
|
+
print(f" New trace type: {type(session_to_upload)}")
|
453
|
+
print(f" This should NEVER happen with UUID-based session IDs!")
|
454
|
+
else:
|
455
|
+
_SESSIONS[session_id] = (experiment_id, session_to_upload)
|
456
|
+
print(f"🔵 Stored session {session_id} for batch upload")
|
457
|
+
print(f" Experiment ID: {experiment_id}")
|
458
|
+
print(f" Trace type: {type(session_to_upload)}")
|
459
|
+
if hasattr(session_to_upload, 'num_timesteps'):
|
460
|
+
print(f" Timesteps: {session_to_upload.num_timesteps}")
|
461
|
+
|
462
|
+
# Verify uniqueness by checking all existing session IDs
|
463
|
+
all_session_ids = list(_SESSIONS.keys())
|
464
|
+
if len(all_session_ids) != len(set(all_session_ids)):
|
465
|
+
print(f"🚨 CRITICAL: Session ID collision detected!")
|
466
|
+
print(f" Total sessions: {len(all_session_ids)}")
|
467
|
+
print(f" Unique sessions: {len(set(all_session_ids))}")
|
468
|
+
print(f" Collisions: {len(all_session_ids) - len(set(all_session_ids))}")
|
469
|
+
from collections import Counter
|
470
|
+
duplicates = [sid for sid, count in Counter(all_session_ids).items() if count > 1]
|
471
|
+
print(f" Duplicate IDs: {duplicates}")
|
472
|
+
return result
|
473
|
+
|
474
|
+
except Exception as e:
|
475
|
+
print(f"❌ Episode {episode_id} for {model_name} failed: {e}")
|
476
|
+
return {
|
477
|
+
"episode_id": episode_id,
|
478
|
+
"model": model_name,
|
479
|
+
"error": str(e),
|
480
|
+
"duration": 0.0
|
481
|
+
}
|
482
|
+
|
483
|
+
|
484
|
+
async def run_experiment_for_model(model_name: str, config: ExperimentConfig) -> Tuple[str, List[Dict[str, Any]]]:
|
485
|
+
"""Run complete experiment for a single model with concurrent episodes."""
|
486
|
+
if not config.quiet:
|
487
|
+
print(f"\n🚀 Starting experiment for {model_name}")
|
488
|
+
print(f" Episodes: {config.num_episodes}")
|
489
|
+
print(f" Max turns: {config.max_turns}")
|
490
|
+
|
491
|
+
# Create experiment ID
|
492
|
+
experiment_id = str(uuid.uuid4())
|
493
|
+
experiment_name = f"crafter_{model_name.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
494
|
+
|
495
|
+
# Create experiment in database
|
496
|
+
with DuckDBTraceManager(config.database_path) as db_manager:
|
497
|
+
try:
|
498
|
+
db_manager.create_experiment(
|
499
|
+
experiment_id=experiment_id,
|
500
|
+
name=experiment_name,
|
501
|
+
description=f"Crafter evaluation with {model_name}"
|
502
|
+
)
|
503
|
+
except Exception as e:
|
504
|
+
print(f"Warning: Could not create experiment in DB: {e}")
|
505
|
+
|
506
|
+
# Create async tasks for all episodes
|
507
|
+
episode_tasks = []
|
508
|
+
for i in range(config.num_episodes):
|
509
|
+
task = run_episode_async(i, model_name, config, experiment_id)
|
510
|
+
episode_tasks.append(task)
|
511
|
+
|
512
|
+
if not config.quiet:
|
513
|
+
print(f"📍 Running {config.num_episodes} episodes concurrently for {model_name}")
|
514
|
+
|
515
|
+
# Run all episodes concurrently with progress tracking
|
516
|
+
with tqdm(total=config.num_episodes, desc=f"{model_name} Episodes") as pbar:
|
517
|
+
results = []
|
518
|
+
|
519
|
+
# Use asyncio.as_completed to get results as they finish
|
520
|
+
for coro in asyncio.as_completed(episode_tasks):
|
521
|
+
result = await coro
|
522
|
+
results.append(result)
|
523
|
+
pbar.update(1)
|
524
|
+
|
525
|
+
# Sort results by episode_id to maintain order
|
526
|
+
results.sort(key=lambda x: x.get("episode_id", 0))
|
527
|
+
|
528
|
+
if not config.quiet:
|
529
|
+
print(f"✅ Completed experiment for {model_name}")
|
530
|
+
return experiment_id, results
|
531
|
+
|
532
|
+
|
533
|
+
def analyze_invalid_actions(all_results: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
|
534
|
+
"""Analyze invalid action rates across models."""
|
535
|
+
analysis = {}
|
536
|
+
|
537
|
+
for model_name, results in all_results.items():
|
538
|
+
successful_episodes = [r for r in results if "error" not in r]
|
539
|
+
|
540
|
+
if successful_episodes:
|
541
|
+
invalid_rates = [r["invalid_action_rate"] for r in successful_episodes]
|
542
|
+
total_invalid = sum(r["invalid_actions"] for r in successful_episodes)
|
543
|
+
total_actions = sum(r["total_actions"] for r in successful_episodes)
|
544
|
+
|
545
|
+
analysis[model_name] = {
|
546
|
+
"avg_invalid_rate": np.mean(invalid_rates),
|
547
|
+
"std_invalid_rate": np.std(invalid_rates),
|
548
|
+
"total_invalid_actions": total_invalid,
|
549
|
+
"total_actions": total_actions,
|
550
|
+
"overall_invalid_rate": total_invalid / total_actions if total_actions > 0 else 0.0
|
551
|
+
}
|
552
|
+
else:
|
553
|
+
analysis[model_name] = {
|
554
|
+
"avg_invalid_rate": 0.0,
|
555
|
+
"std_invalid_rate": 0.0,
|
556
|
+
"total_invalid_actions": 0,
|
557
|
+
"total_actions": 0,
|
558
|
+
"overall_invalid_rate": 0.0
|
559
|
+
}
|
560
|
+
|
561
|
+
return analysis
|
562
|
+
|
563
|
+
|
564
|
+
def analyze_achievements_by_step(all_results: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
|
565
|
+
"""Analyze achievement frequencies by step across models."""
|
566
|
+
analysis = {}
|
567
|
+
|
568
|
+
for model_name, results in all_results.items():
|
569
|
+
successful_episodes = [r for r in results if "error" not in r]
|
570
|
+
|
571
|
+
achievement_by_step = defaultdict(list)
|
572
|
+
all_achievements = []
|
573
|
+
|
574
|
+
for result in successful_episodes:
|
575
|
+
for step_result in result.get("step_results", []):
|
576
|
+
step = step_result["step"]
|
577
|
+
achievement = step_result.get("achievement")
|
578
|
+
|
579
|
+
if achievement:
|
580
|
+
achievement_by_step[step].append(achievement)
|
581
|
+
all_achievements.append(achievement)
|
582
|
+
|
583
|
+
# Count unique achievements
|
584
|
+
achievement_counts = Counter(all_achievements)
|
585
|
+
|
586
|
+
# Calculate achievement frequency by step ranges
|
587
|
+
step_ranges = [(0, 25), (26, 50), (51, 75), (76, 100)]
|
588
|
+
achievements_by_range = {}
|
589
|
+
|
590
|
+
for range_start, range_end in step_ranges:
|
591
|
+
range_achievements = []
|
592
|
+
for step in range(range_start, range_end + 1):
|
593
|
+
range_achievements.extend(achievement_by_step.get(step, []))
|
594
|
+
|
595
|
+
achievements_by_range[f"{range_start}-{range_end}"] = {
|
596
|
+
"count": len(range_achievements),
|
597
|
+
"unique": len(set(range_achievements)),
|
598
|
+
"achievements": list(set(range_achievements))
|
599
|
+
}
|
600
|
+
|
601
|
+
analysis[model_name] = {
|
602
|
+
"total_achievements": len(all_achievements),
|
603
|
+
"unique_achievements": len(set(all_achievements)),
|
604
|
+
"achievement_counts": dict(achievement_counts),
|
605
|
+
"achievements_by_step_range": achievements_by_range
|
606
|
+
}
|
607
|
+
|
608
|
+
return analysis
|
609
|
+
|
610
|
+
|
611
|
+
def analyze_performance_metrics(all_results: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
|
612
|
+
"""Analyze overall performance metrics across models."""
|
613
|
+
analysis = {}
|
614
|
+
|
615
|
+
for model_name, results in all_results.items():
|
616
|
+
successful_episodes = [r for r in results if "error" not in r]
|
617
|
+
failed_episodes = [r for r in results if "error" in r]
|
618
|
+
|
619
|
+
if successful_episodes:
|
620
|
+
rewards = [r["total_reward"] for r in successful_episodes]
|
621
|
+
steps = [r["steps"] for r in successful_episodes]
|
622
|
+
durations = [r["duration"] for r in successful_episodes]
|
623
|
+
|
624
|
+
analysis[model_name] = {
|
625
|
+
"total_episodes": len(results),
|
626
|
+
"successful_episodes": len(successful_episodes),
|
627
|
+
"failed_episodes": len(failed_episodes),
|
628
|
+
"success_rate": len(successful_episodes) / len(results),
|
629
|
+
"avg_reward": np.mean(rewards),
|
630
|
+
"std_reward": np.std(rewards),
|
631
|
+
"max_reward": np.max(rewards),
|
632
|
+
"min_reward": np.min(rewards),
|
633
|
+
"avg_steps": np.mean(steps),
|
634
|
+
"avg_duration": np.mean(durations)
|
635
|
+
}
|
636
|
+
else:
|
637
|
+
analysis[model_name] = {
|
638
|
+
"total_episodes": len(results),
|
639
|
+
"successful_episodes": 0,
|
640
|
+
"failed_episodes": len(failed_episodes),
|
641
|
+
"success_rate": 0.0,
|
642
|
+
"avg_reward": 0.0,
|
643
|
+
"std_reward": 0.0,
|
644
|
+
"max_reward": 0.0,
|
645
|
+
"min_reward": 0.0,
|
646
|
+
"avg_steps": 0.0,
|
647
|
+
"avg_duration": 0.0
|
648
|
+
}
|
649
|
+
|
650
|
+
return analysis
|
651
|
+
|
652
|
+
|
653
|
+
def print_results_summary(all_results: Dict[str, List[Dict[str, Any]]],
|
654
|
+
experiment_ids: Dict[str, str]):
|
655
|
+
"""Print comprehensive results summary."""
|
656
|
+
print("\n" + "="*100)
|
657
|
+
print("🏆 MULTI-MODEL CRAFTER EVALUATION RESULTS")
|
658
|
+
print("="*100)
|
659
|
+
|
660
|
+
# Performance metrics
|
661
|
+
performance_analysis = analyze_performance_metrics(all_results)
|
662
|
+
|
663
|
+
print("\n📊 PERFORMANCE SUMMARY")
|
664
|
+
print("-" * 80)
|
665
|
+
print(f"{'Model':<20} {'Episodes':<10} {'Success%':<10} {'Avg Reward':<12} {'Avg Steps':<12} {'Avg Duration':<12}")
|
666
|
+
print("-" * 80)
|
667
|
+
|
668
|
+
for model_name in MODELS_TO_TEST:
|
669
|
+
if model_name in performance_analysis:
|
670
|
+
perf = performance_analysis[model_name]
|
671
|
+
print(f"{model_name:<20} {perf['total_episodes']:<10} {perf['success_rate']*100:<9.1f}% {perf['avg_reward']:<11.2f} {perf['avg_steps']:<11.1f} {perf['avg_duration']:<11.1f}s")
|
672
|
+
|
673
|
+
# Invalid action analysis
|
674
|
+
invalid_analysis = analyze_invalid_actions(all_results)
|
675
|
+
|
676
|
+
print("\n🚫 INVALID ACTION ANALYSIS")
|
677
|
+
print("-" * 80)
|
678
|
+
print(f"{'Model':<20} {'Avg Invalid%':<15} {'Total Invalid':<15} {'Total Actions':<15}")
|
679
|
+
print("-" * 80)
|
680
|
+
|
681
|
+
for model_name in MODELS_TO_TEST:
|
682
|
+
if model_name in invalid_analysis:
|
683
|
+
inv = invalid_analysis[model_name]
|
684
|
+
print(f"{model_name:<20} {inv['avg_invalid_rate']*100:<14.2f}% {inv['total_invalid_actions']:<14} {inv['total_actions']:<14}")
|
685
|
+
|
686
|
+
# Achievement analysis
|
687
|
+
achievement_analysis = analyze_achievements_by_step(all_results)
|
688
|
+
|
689
|
+
print("\n🏅 ACHIEVEMENT ANALYSIS")
|
690
|
+
print("-" * 80)
|
691
|
+
print(f"{'Model':<20} {'Total Ach.':<12} {'Unique Ach.':<12} {'Early (0-25)':<12} {'Mid (26-50)':<12} {'Late (51+)':<12}")
|
692
|
+
print("-" * 80)
|
693
|
+
|
694
|
+
for model_name in MODELS_TO_TEST:
|
695
|
+
if model_name in achievement_analysis:
|
696
|
+
ach = achievement_analysis[model_name]
|
697
|
+
early = ach['achievements_by_step_range'].get('0-25', {}).get('count', 0)
|
698
|
+
mid = ach['achievements_by_step_range'].get('26-50', {}).get('count', 0)
|
699
|
+
late1 = ach['achievements_by_step_range'].get('51-75', {}).get('count', 0)
|
700
|
+
late2 = ach['achievements_by_step_range'].get('76-100', {}).get('count', 0)
|
701
|
+
late = late1 + late2
|
702
|
+
|
703
|
+
print(f"{model_name:<20} {ach['total_achievements']:<11} {ach['unique_achievements']:<11} {early:<11} {mid:<11} {late:<11}")
|
704
|
+
|
705
|
+
# Model ranking
|
706
|
+
print("\n🥇 MODEL RANKINGS")
|
707
|
+
print("-" * 50)
|
708
|
+
|
709
|
+
# Rank by average reward
|
710
|
+
reward_ranking = sorted([(model, perf['avg_reward']) for model, perf in performance_analysis.items()],
|
711
|
+
key=lambda x: x[1], reverse=True)
|
712
|
+
|
713
|
+
print("By Average Reward:")
|
714
|
+
for i, (model, reward) in enumerate(reward_ranking, 1):
|
715
|
+
print(f" {i}. {model}: {reward:.2f}")
|
716
|
+
|
717
|
+
# Rank by invalid action rate (lower is better)
|
718
|
+
invalid_ranking = sorted([(model, inv['avg_invalid_rate']) for model, inv in invalid_analysis.items()],
|
719
|
+
key=lambda x: x[1])
|
720
|
+
|
721
|
+
print("\nBy Invalid Action Rate (lower is better):")
|
722
|
+
for i, (model, rate) in enumerate(invalid_ranking, 1):
|
723
|
+
print(f" {i}. {model}: {rate*100:.2f}%")
|
724
|
+
|
725
|
+
# Experiment IDs
|
726
|
+
print("\n🆔 EXPERIMENT IDS")
|
727
|
+
print("-" * 50)
|
728
|
+
for model_name, exp_id in experiment_ids.items():
|
729
|
+
print(f"{model_name}: {exp_id}")
|
730
|
+
|
731
|
+
print("\n" + "="*100)
|
732
|
+
|
733
|
+
|
734
|
+
def print_comprehensive_model_analytics(database_path: str, experiment_ids: Dict[str, str], quiet: bool = False):
|
735
|
+
"""Generate comprehensive model analytics from DuckDB data."""
|
736
|
+
try:
|
737
|
+
from synth_ai.tracing_v2.duckdb.manager import DuckDBTraceManager
|
738
|
+
import pandas as pd
|
739
|
+
|
740
|
+
with DuckDBTraceManager(database_path) as db:
|
741
|
+
if not quiet:
|
742
|
+
print("\n🔍 COMPREHENSIVE MODEL ANALYTICS")
|
743
|
+
print("=" * 80)
|
744
|
+
|
745
|
+
# 1. Model Performance Summary
|
746
|
+
print_model_performance_summary(db, experiment_ids, quiet)
|
747
|
+
|
748
|
+
# 2. Achievement Analysis
|
749
|
+
print_achievement_analysis(db, experiment_ids, quiet)
|
750
|
+
|
751
|
+
# 3. Action Analysis
|
752
|
+
print_action_analysis(db, experiment_ids, quiet)
|
753
|
+
|
754
|
+
# 4. Efficiency Metrics
|
755
|
+
print_efficiency_metrics(db, experiment_ids, quiet)
|
756
|
+
|
757
|
+
# 5. Error Analysis
|
758
|
+
print_error_analysis(db, experiment_ids, quiet)
|
759
|
+
|
760
|
+
except Exception as e:
|
761
|
+
if not quiet:
|
762
|
+
print(f"⚠️ Failed to generate analytics: {e}")
|
763
|
+
|
764
|
+
|
765
|
+
def print_model_performance_summary(db, experiment_ids: Dict[str, str], quiet: bool):
|
766
|
+
"""Print overall model performance summary."""
|
767
|
+
if not quiet:
|
768
|
+
print("\n## 📊 Model Performance Summary")
|
769
|
+
print("-" * 40)
|
770
|
+
|
771
|
+
try:
|
772
|
+
# Get session-level metrics by model (using reward from runtime metadata)
|
773
|
+
valid_experiment_ids = [eid for eid in experiment_ids.values() if eid != "failed"]
|
774
|
+
if not valid_experiment_ids:
|
775
|
+
print("No performance data available")
|
776
|
+
return
|
777
|
+
|
778
|
+
query = """
|
779
|
+
WITH session_rewards AS (
|
780
|
+
SELECT
|
781
|
+
st.session_id,
|
782
|
+
st.experiment_id,
|
783
|
+
st.num_timesteps,
|
784
|
+
SUM(COALESCE(CAST(json_extract(ev.metadata, '$.reward') AS DOUBLE), 0)) as session_total_reward
|
785
|
+
FROM session_traces st
|
786
|
+
LEFT JOIN events ev ON st.session_id = ev.session_id
|
787
|
+
AND ev.event_type = 'runtime'
|
788
|
+
AND json_extract(ev.metadata, '$.reward') IS NOT NULL
|
789
|
+
GROUP BY st.session_id, st.experiment_id, st.num_timesteps
|
790
|
+
)
|
791
|
+
SELECT
|
792
|
+
experiment_id,
|
793
|
+
COUNT(session_id) as episodes,
|
794
|
+
AVG(num_timesteps) as avg_steps,
|
795
|
+
MAX(num_timesteps) as max_steps,
|
796
|
+
MIN(num_timesteps) as min_steps,
|
797
|
+
AVG(session_total_reward) as avg_reward,
|
798
|
+
SUM(session_total_reward) as total_reward
|
799
|
+
FROM session_rewards
|
800
|
+
WHERE experiment_id IN ({})
|
801
|
+
GROUP BY experiment_id
|
802
|
+
ORDER BY total_reward DESC
|
803
|
+
""".format(','.join([f"'{eid}'" for eid in valid_experiment_ids]))
|
804
|
+
|
805
|
+
df = db.query_traces(query)
|
806
|
+
|
807
|
+
if df.empty:
|
808
|
+
print("No performance data available")
|
809
|
+
return
|
810
|
+
|
811
|
+
# Map experiment IDs back to model names
|
812
|
+
exp_to_model = {v: k for k, v in experiment_ids.items() if v != "failed"}
|
813
|
+
df['model'] = df['experiment_id'].map(exp_to_model)
|
814
|
+
|
815
|
+
# Create performance table
|
816
|
+
if not quiet:
|
817
|
+
print("\n| Model | Episodes | Avg Steps | Max Steps | Total Reward | Avg Reward |")
|
818
|
+
print("|-------|----------|-----------|-----------|--------------|------------|")
|
819
|
+
|
820
|
+
for _, row in df.iterrows():
|
821
|
+
print(f"| {row['model']:<12} | {int(row['episodes']):>8} | {row['avg_steps']:>9.1f} | {int(row['max_steps']):>9} | {row['total_reward']:>12.1f} | {row['avg_reward']:>10.3f} |")
|
822
|
+
else:
|
823
|
+
# Quiet mode - just show winners
|
824
|
+
if not df.empty and 'total_reward' in df.columns and 'avg_reward' in df.columns:
|
825
|
+
valid_total_df = df[df['total_reward'].notna()]
|
826
|
+
valid_avg_df = df[df['avg_reward'].notna()]
|
827
|
+
|
828
|
+
if not valid_total_df.empty:
|
829
|
+
best_reward = valid_total_df.loc[valid_total_df['total_reward'].idxmax()]
|
830
|
+
print(f"🏆 Best Total Reward: {best_reward['model']} ({best_reward['total_reward']:.1f})")
|
831
|
+
|
832
|
+
if not valid_avg_df.empty:
|
833
|
+
best_efficiency = valid_avg_df.loc[valid_avg_df['avg_reward'].idxmax()]
|
834
|
+
print(f"⚡ Most Efficient: {best_efficiency['model']} ({best_efficiency['avg_reward']:.3f} avg reward)")
|
835
|
+
|
836
|
+
except Exception as e:
|
837
|
+
if not quiet:
|
838
|
+
print(f"Failed to get performance summary: {e}")
|
839
|
+
|
840
|
+
|
841
|
+
def print_achievement_analysis(db, experiment_ids: Dict[str, str], quiet: bool):
|
842
|
+
"""Analyze achievement patterns across models."""
|
843
|
+
if not quiet:
|
844
|
+
print("\n## 🏆 Achievement Analysis")
|
845
|
+
print("-" * 30)
|
846
|
+
|
847
|
+
try:
|
848
|
+
# Get achievement counts by model (simplified approach looking in event_metadata)
|
849
|
+
valid_experiment_ids = [eid for eid in experiment_ids.values() if eid != "failed"]
|
850
|
+
if not valid_experiment_ids:
|
851
|
+
print("No achievement data available")
|
852
|
+
return
|
853
|
+
|
854
|
+
query = """
|
855
|
+
SELECT
|
856
|
+
st.experiment_id,
|
857
|
+
COALESCE(
|
858
|
+
json_extract(ev.metadata, '$.achievement'),
|
859
|
+
json_extract(ev.event_metadata, '$[0].achievement'),
|
860
|
+
'generic_achievement'
|
861
|
+
) as achievement,
|
862
|
+
'unknown' as difficulty,
|
863
|
+
COUNT(*) as achievement_count
|
864
|
+
FROM session_traces st
|
865
|
+
JOIN events ev ON st.session_id = ev.session_id
|
866
|
+
WHERE st.experiment_id IN ({})
|
867
|
+
AND ev.event_type = 'runtime'
|
868
|
+
AND (
|
869
|
+
json_extract(ev.metadata, '$.achievement') IS NOT NULL
|
870
|
+
OR (ev.event_metadata IS NOT NULL
|
871
|
+
AND ev.event_metadata != '[]'
|
872
|
+
AND ev.event_metadata LIKE '%achievement%')
|
873
|
+
)
|
874
|
+
GROUP BY st.experiment_id, achievement
|
875
|
+
ORDER BY achievement_count DESC
|
876
|
+
""".format(','.join([f"'{eid}'" for eid in valid_experiment_ids]))
|
877
|
+
|
878
|
+
df = db.query_traces(query)
|
879
|
+
|
880
|
+
if df.empty:
|
881
|
+
print("No achievement data available")
|
882
|
+
return
|
883
|
+
|
884
|
+
# Map experiment IDs back to model names
|
885
|
+
exp_to_model = {v: k for k, v in experiment_ids.items() if v != "failed"}
|
886
|
+
df['model'] = df['experiment_id'].map(exp_to_model)
|
887
|
+
|
888
|
+
# Pivot table for achievements by model
|
889
|
+
pivot = df.pivot_table(index='achievement', columns='model', values='achievement_count', fill_value=0)
|
890
|
+
|
891
|
+
if not quiet:
|
892
|
+
print("\n### Achievement Counts by Model:")
|
893
|
+
print(pivot.to_string())
|
894
|
+
|
895
|
+
# Show top achievements
|
896
|
+
total_achievements = df.groupby('achievement')['achievement_count'].sum().sort_values(ascending=False)
|
897
|
+
print(f"\n### Most Common Achievements:")
|
898
|
+
for i, (achievement, count) in enumerate(total_achievements.head(5).items()):
|
899
|
+
print(f"{i+1}. {achievement}: {count} times")
|
900
|
+
else:
|
901
|
+
# Show just the winners
|
902
|
+
total_by_model = df.groupby('model')['achievement_count'].sum().sort_values(ascending=False)
|
903
|
+
if not total_by_model.empty:
|
904
|
+
best_model = total_by_model.index[0]
|
905
|
+
print(f"🏆 Most Achievements: {best_model} ({total_by_model.iloc[0]} total)")
|
906
|
+
|
907
|
+
except Exception as e:
|
908
|
+
if not quiet:
|
909
|
+
print(f"Failed to analyze achievements: {e}")
|
910
|
+
|
911
|
+
|
912
|
+
def print_action_analysis(db, experiment_ids: Dict[str, str], quiet: bool):
|
913
|
+
"""Analyze action patterns and invalid actions."""
|
914
|
+
if not quiet:
|
915
|
+
print("\n## 🎮 Action Analysis")
|
916
|
+
print("-" * 25)
|
917
|
+
|
918
|
+
try:
|
919
|
+
# Get invalid action rates by model (simplified approach)
|
920
|
+
valid_experiment_ids = [eid for eid in experiment_ids.values() if eid != "failed"]
|
921
|
+
if not valid_experiment_ids:
|
922
|
+
print("No action data available")
|
923
|
+
return
|
924
|
+
|
925
|
+
query = """
|
926
|
+
SELECT
|
927
|
+
st.experiment_id,
|
928
|
+
COUNT(*) as total_actions,
|
929
|
+
SUM(CASE
|
930
|
+
WHEN COALESCE(
|
931
|
+
CAST(json_extract(ev.metadata, '$.invalid_action') AS BOOLEAN),
|
932
|
+
ev.event_metadata LIKE '%invalid_action%'
|
933
|
+
) THEN 1
|
934
|
+
ELSE 0
|
935
|
+
END) as invalid_actions,
|
936
|
+
ROUND(100.0 * SUM(CASE
|
937
|
+
WHEN COALESCE(
|
938
|
+
CAST(json_extract(ev.metadata, '$.invalid_action') AS BOOLEAN),
|
939
|
+
ev.event_metadata LIKE '%invalid_action%'
|
940
|
+
) THEN 1
|
941
|
+
ELSE 0
|
942
|
+
END) / NULLIF(COUNT(*), 0), 2) as invalid_rate
|
943
|
+
FROM session_traces st
|
944
|
+
JOIN events ev ON st.session_id = ev.session_id
|
945
|
+
WHERE st.experiment_id IN ({})
|
946
|
+
AND ev.event_type = 'runtime'
|
947
|
+
AND json_extract(ev.metadata, '$.step') IS NOT NULL
|
948
|
+
GROUP BY st.experiment_id
|
949
|
+
ORDER BY invalid_rate ASC
|
950
|
+
""".format(','.join([f"'{eid}'" for eid in valid_experiment_ids]))
|
951
|
+
|
952
|
+
df = db.query_traces(query)
|
953
|
+
|
954
|
+
if df.empty:
|
955
|
+
print("No action data available")
|
956
|
+
return
|
957
|
+
|
958
|
+
# Map experiment IDs back to model names
|
959
|
+
exp_to_model = {v: k for k, v in experiment_ids.items() if v != "failed"}
|
960
|
+
df['model'] = df['experiment_id'].map(exp_to_model)
|
961
|
+
|
962
|
+
if not quiet:
|
963
|
+
print("\n| Model | Total Actions | Invalid Actions | Invalid Rate |")
|
964
|
+
print("|-------|---------------|-----------------|--------------|")
|
965
|
+
|
966
|
+
for _, row in df.iterrows():
|
967
|
+
print(f"| {row['model']:<12} | {int(row['total_actions']):>13} | {int(row['invalid_actions']):>15} | {row['invalid_rate']:>10.1f}% |")
|
968
|
+
else:
|
969
|
+
# Show best and worst
|
970
|
+
if not df.empty and 'invalid_rate' in df.columns:
|
971
|
+
valid_df = df[df['invalid_rate'].notna()]
|
972
|
+
if not valid_df.empty:
|
973
|
+
best_model = valid_df.loc[valid_df['invalid_rate'].idxmin()]
|
974
|
+
worst_model = valid_df.loc[valid_df['invalid_rate'].idxmax()]
|
975
|
+
print(f"🎯 Most Accurate: {best_model['model']} ({best_model['invalid_rate']:.1f}% invalid)")
|
976
|
+
print(f"❌ Least Accurate: {worst_model['model']} ({worst_model['invalid_rate']:.1f}% invalid)")
|
977
|
+
|
978
|
+
except Exception as e:
|
979
|
+
if not quiet:
|
980
|
+
print(f"Failed to analyze actions: {e}")
|
981
|
+
|
982
|
+
|
983
|
+
def print_efficiency_metrics(db, experiment_ids: Dict[str, str], quiet: bool):
|
984
|
+
"""Analyze efficiency metrics like tokens and cost."""
|
985
|
+
if not quiet:
|
986
|
+
print("\n## ⚡ Efficiency Metrics")
|
987
|
+
print("-" * 25)
|
988
|
+
|
989
|
+
try:
|
990
|
+
# Get token usage and cost by model (look for events with LLM data)
|
991
|
+
valid_experiment_ids = [eid for eid in experiment_ids.values() if eid != "failed"]
|
992
|
+
if not valid_experiment_ids:
|
993
|
+
print("No efficiency data available")
|
994
|
+
return
|
995
|
+
|
996
|
+
query = """
|
997
|
+
SELECT
|
998
|
+
st.experiment_id,
|
999
|
+
COUNT(*) as llm_calls,
|
1000
|
+
AVG(ev.prompt_tokens) as avg_prompt_tokens,
|
1001
|
+
AVG(ev.completion_tokens) as avg_completion_tokens,
|
1002
|
+
SUM(ev.total_tokens) as total_tokens,
|
1003
|
+
AVG(ev.latency_ms) as avg_latency_ms,
|
1004
|
+
SUM(COALESCE(ev.cost, 0)) as total_cost
|
1005
|
+
FROM session_traces st
|
1006
|
+
JOIN events ev ON st.session_id = ev.session_id
|
1007
|
+
WHERE st.experiment_id IN ({})
|
1008
|
+
AND ev.model_name IS NOT NULL
|
1009
|
+
AND ev.prompt_tokens IS NOT NULL
|
1010
|
+
GROUP BY st.experiment_id
|
1011
|
+
ORDER BY total_cost ASC
|
1012
|
+
""".format(','.join([f"'{eid}'" for eid in valid_experiment_ids]))
|
1013
|
+
|
1014
|
+
df = db.query_traces(query)
|
1015
|
+
|
1016
|
+
if df.empty:
|
1017
|
+
print("No efficiency data available")
|
1018
|
+
return
|
1019
|
+
|
1020
|
+
# Map experiment IDs back to model names
|
1021
|
+
exp_to_model = {v: k for k, v in experiment_ids.items() if v != "failed"}
|
1022
|
+
df['model'] = df['experiment_id'].map(exp_to_model)
|
1023
|
+
|
1024
|
+
if not quiet:
|
1025
|
+
print("\n| Model | LLM Calls | Avg Prompt | Avg Completion | Total Tokens | Avg Latency | Total Cost |")
|
1026
|
+
print("|-------|-----------|------------|----------------|--------------|-------------|------------|")
|
1027
|
+
|
1028
|
+
for _, row in df.iterrows():
|
1029
|
+
cost = row['total_cost'] if pd.notna(row['total_cost']) else 0.0
|
1030
|
+
latency = row['avg_latency_ms'] if pd.notna(row['avg_latency_ms']) else 0.0
|
1031
|
+
prompt_tokens = row['avg_prompt_tokens'] if pd.notna(row['avg_prompt_tokens']) else 0.0
|
1032
|
+
completion_tokens = row['avg_completion_tokens'] if pd.notna(row['avg_completion_tokens']) else 0.0
|
1033
|
+
total_tokens = row['total_tokens'] if pd.notna(row['total_tokens']) else 0
|
1034
|
+
llm_calls = row['llm_calls'] if pd.notna(row['llm_calls']) else 0
|
1035
|
+
print(f"| {row['model']:<8} | {int(llm_calls):>9} | {prompt_tokens:>10.0f} | {completion_tokens:>14.0f} | {int(total_tokens):>12} | {latency:>9.0f}ms | ${cost:>9.4f} |")
|
1036
|
+
else:
|
1037
|
+
# Show most efficient
|
1038
|
+
if 'total_cost' in df.columns and not df['total_cost'].isna().all():
|
1039
|
+
valid_cost_df = df[df['total_cost'].notna() & (df['total_cost'] > 0)]
|
1040
|
+
if not valid_cost_df.empty:
|
1041
|
+
cheapest = valid_cost_df.loc[valid_cost_df['total_cost'].idxmin()]
|
1042
|
+
print(f"💰 Most Cost-Efficient: {cheapest['model']} (${cheapest['total_cost']:.4f})")
|
1043
|
+
|
1044
|
+
if 'avg_latency_ms' in df.columns and not df['avg_latency_ms'].isna().all():
|
1045
|
+
valid_latency_df = df[df['avg_latency_ms'].notna() & (df['avg_latency_ms'] > 0)]
|
1046
|
+
if not valid_latency_df.empty:
|
1047
|
+
fastest = valid_latency_df.loc[valid_latency_df['avg_latency_ms'].idxmin()]
|
1048
|
+
print(f"🚀 Fastest: {fastest['model']} ({fastest['avg_latency_ms']:.0f}ms avg)")
|
1049
|
+
|
1050
|
+
except Exception as e:
|
1051
|
+
if not quiet:
|
1052
|
+
print(f"Failed to analyze efficiency: {e}")
|
1053
|
+
|
1054
|
+
|
1055
|
+
def print_error_analysis(db, experiment_ids: Dict[str, str], quiet: bool):
|
1056
|
+
"""Analyze error patterns and failure modes."""
|
1057
|
+
if not quiet:
|
1058
|
+
print("\n## 🔍 Error Analysis")
|
1059
|
+
print("-" * 20)
|
1060
|
+
|
1061
|
+
try:
|
1062
|
+
# Look for termination patterns by checking if episodes completed all steps
|
1063
|
+
valid_experiment_ids = [eid for eid in experiment_ids.values() if eid != "failed"]
|
1064
|
+
if not valid_experiment_ids:
|
1065
|
+
if not quiet:
|
1066
|
+
print("No error data available")
|
1067
|
+
return
|
1068
|
+
|
1069
|
+
query = """
|
1070
|
+
SELECT
|
1071
|
+
st.experiment_id,
|
1072
|
+
CASE
|
1073
|
+
WHEN st.num_timesteps < 100 THEN 'early_termination'
|
1074
|
+
WHEN st.num_timesteps >= 100 THEN 'max_steps_reached'
|
1075
|
+
ELSE 'unknown'
|
1076
|
+
END as termination_reason,
|
1077
|
+
COUNT(*) as episode_count
|
1078
|
+
FROM session_traces st
|
1079
|
+
WHERE st.experiment_id IN ({})
|
1080
|
+
GROUP BY st.experiment_id, termination_reason
|
1081
|
+
ORDER BY episode_count DESC
|
1082
|
+
""".format(','.join([f"'{eid}'" for eid in valid_experiment_ids]))
|
1083
|
+
|
1084
|
+
df = db.query_traces(query)
|
1085
|
+
|
1086
|
+
if not df.empty:
|
1087
|
+
# Map experiment IDs back to model names
|
1088
|
+
exp_to_model = {v: k for k, v in experiment_ids.items() if v != "failed"}
|
1089
|
+
df['model'] = df['experiment_id'].map(exp_to_model)
|
1090
|
+
|
1091
|
+
if not quiet:
|
1092
|
+
print("\n### Episode Termination Reasons:")
|
1093
|
+
pivot = df.pivot_table(index='termination_reason', columns='model', values='episode_count', fill_value=0)
|
1094
|
+
print(pivot.to_string())
|
1095
|
+
else:
|
1096
|
+
# Show most common termination reason
|
1097
|
+
most_common = df.groupby('termination_reason')['episode_count'].sum().sort_values(ascending=False)
|
1098
|
+
if not most_common.empty:
|
1099
|
+
print(f"🔚 Most Common Termination: {most_common.index[0]} ({most_common.iloc[0]} episodes)")
|
1100
|
+
|
1101
|
+
except Exception as e:
|
1102
|
+
if not quiet:
|
1103
|
+
print(f"Failed to analyze errors: {e}")
|
1104
|
+
|
1105
|
+
|
1106
|
+
async def main():
|
1107
|
+
"""Main execution function."""
|
1108
|
+
parser = argparse.ArgumentParser(
|
1109
|
+
description="Run Crafter rollouts for multiple models and compare performance"
|
1110
|
+
)
|
1111
|
+
parser.add_argument("--episodes", type=int, default=10, help="Episodes per model")
|
1112
|
+
parser.add_argument("--max-turns", type=int, default=100, help="Max turns per episode")
|
1113
|
+
parser.add_argument("--models", nargs="+", default=MODELS_TO_TEST,
|
1114
|
+
help="Models to test")
|
1115
|
+
parser.add_argument("--database", default=DATABASE_PATH, help="Database path")
|
1116
|
+
parser.add_argument("--service-url", default=CRAFTER_SERVICE_URL,
|
1117
|
+
help="Crafter service URL")
|
1118
|
+
parser.add_argument("--concurrent-models", action="store_true",
|
1119
|
+
help="Run models concurrently (default: sequential)")
|
1120
|
+
parser.add_argument("--max-concurrent-models", type=int, default=3,
|
1121
|
+
help="Maximum number of models to run concurrently")
|
1122
|
+
parser.add_argument("--quiet", action="store_true",
|
1123
|
+
help="Suppress verbose output and only show results")
|
1124
|
+
|
1125
|
+
args = parser.parse_args()
|
1126
|
+
|
1127
|
+
# Create configuration
|
1128
|
+
config = ExperimentConfig()
|
1129
|
+
config.num_episodes = args.episodes
|
1130
|
+
config.max_turns = args.max_turns
|
1131
|
+
config.database_path = args.database
|
1132
|
+
config.crafter_service_url = args.service_url
|
1133
|
+
config.quiet = args.quiet
|
1134
|
+
|
1135
|
+
# Suppress all noisy third-party logging if in quiet mode
|
1136
|
+
if config.quiet:
|
1137
|
+
import logging
|
1138
|
+
# Suppress HTTP and API client logging
|
1139
|
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
1140
|
+
logging.getLogger("google_genai").setLevel(logging.WARNING)
|
1141
|
+
logging.getLogger("google_genai.models").setLevel(logging.WARNING)
|
1142
|
+
logging.getLogger("openai").setLevel(logging.WARNING)
|
1143
|
+
logging.getLogger("anthropic").setLevel(logging.WARNING)
|
1144
|
+
# Suppress DuckDB and tracing logging
|
1145
|
+
logging.getLogger("synth_ai.tracing_v2.duckdb").setLevel(logging.WARNING)
|
1146
|
+
logging.getLogger("synth_ai.tracing_v2").setLevel(logging.WARNING)
|
1147
|
+
# Set root logger to WARNING to catch any other noise
|
1148
|
+
logging.getLogger().setLevel(logging.WARNING)
|
1149
|
+
|
1150
|
+
# Create trace directory
|
1151
|
+
os.makedirs(config.v2_trace_dir, exist_ok=True)
|
1152
|
+
|
1153
|
+
# Clear global sessions collection for fresh start
|
1154
|
+
_SESSIONS.clear()
|
1155
|
+
|
1156
|
+
if not config.quiet:
|
1157
|
+
print(f"🚀 STARTING MULTI-MODEL CRAFTER EVALUATION")
|
1158
|
+
print(f" Models: {args.models}")
|
1159
|
+
print(f" Episodes per model: {config.num_episodes}")
|
1160
|
+
print(f" Max turns per episode: {config.max_turns}")
|
1161
|
+
print(f" Database: {config.database_path}")
|
1162
|
+
print(f" Service URL: {config.crafter_service_url}")
|
1163
|
+
print(f" Concurrent models: {args.concurrent_models}")
|
1164
|
+
if args.concurrent_models:
|
1165
|
+
print(f" Max concurrent models: {args.max_concurrent_models}")
|
1166
|
+
|
1167
|
+
# Run experiments for each model
|
1168
|
+
all_results = {}
|
1169
|
+
experiment_ids = {}
|
1170
|
+
|
1171
|
+
+    if args.concurrent_models:
+        # Run models concurrently with semaphore to limit concurrency
+        semaphore = asyncio.Semaphore(args.max_concurrent_models)
+
+        async def run_model_with_semaphore(model_name: str):
+            async with semaphore:
+                try:
+                    return model_name, await run_experiment_for_model(model_name, config)
+                except Exception as e:
+                    print(f"❌ Failed to run experiment for {model_name}: {e}")
+                    import traceback
+                    traceback.print_exc()
+                    return model_name, ("failed", [])
+
+        # Create tasks for all models
+        model_tasks = [run_model_with_semaphore(model_name) for model_name in args.models]
+
+        if not config.quiet:
+            print(f"🔄 Running up to {args.max_concurrent_models} models concurrently...")
+
+        # Run all model experiments concurrently
+        with tqdm(total=len(args.models), desc="Models Completed") as pbar:
+            for coro in asyncio.as_completed(model_tasks):
+                model_name, (experiment_id, results) = await coro
+                all_results[model_name] = results
+                experiment_ids[model_name] = experiment_id
+                pbar.update(1)
+    else:
+        # Run models sequentially (original behavior)
+        for model_name in args.models:
+            try:
+                experiment_id, results = await run_experiment_for_model(model_name, config)
+                all_results[model_name] = results
+                experiment_ids[model_name] = experiment_id
+            except Exception as e:
+                print(f"❌ Failed to run experiment for {model_name}: {e}")
+                import traceback
+                traceback.print_exc()
+                all_results[model_name] = []
+                experiment_ids[model_name] = "failed"
+
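The concurrent branch above bounds parallelism with an `asyncio.Semaphore` and drains the work through `asyncio.as_completed`, so results (and the progress bar) advance in completion order rather than submission order. A minimal, self-contained sketch of that bounded-concurrency pattern, with a stand-in `run_one` coroutine in place of the script's `run_experiment_for_model`:

```python
import asyncio
import random

async def run_one(model_name: str) -> tuple[str, float]:
    # Stand-in for a real per-model experiment: sleep briefly, return a fake score.
    await asyncio.sleep(random.uniform(0.1, 0.3))
    return model_name, random.random()

async def run_all(models: list[str], max_concurrent: int = 3) -> dict[str, float]:
    semaphore = asyncio.Semaphore(max_concurrent)

    async def bounded(name: str) -> tuple[str, float]:
        # At most `max_concurrent` bodies execute at once.
        async with semaphore:
            return await run_one(name)

    results: dict[str, float] = {}
    tasks = [bounded(name) for name in models]
    for coro in asyncio.as_completed(tasks):
        name, score = await coro  # yields in completion order, not input order
        results[name] = score
    return results

if __name__ == "__main__":
    print(asyncio.run(run_all(["model-a", "model-b", "model-c", "model-d"], max_concurrent=2)))
```

The same shape generalizes to any per-item coroutine; only the semaphore size changes the effective concurrency.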
+    # Now do bulk upload of all collected sessions in single transaction
+    if not config.quiet:
+        print("📤 Uploading all session traces to database...")
+        print(f"📊 Found {len(_SESSIONS)} sessions to upload")
+
+        # DEBUG: Check for duplicate session IDs in our collection
+        session_ids = list(_SESSIONS.keys())
+        unique_ids = set(session_ids)
+        if len(session_ids) != len(unique_ids):
+            duplicates = len(session_ids) - len(unique_ids)
+            print(f"🚨 FOUND {duplicates} DUPLICATE SESSION IDs IN COLLECTION!")
+            from collections import Counter
+            id_counts = Counter(session_ids)
+            for session_id, count in id_counts.items():
+                if count > 1:
+                    print(f" - {session_id}: {count} times")
+
+        # First check what's already in the database
+        with DuckDBTraceManager(config.database_path) as db:
+            existing_sessions = db.conn.execute("SELECT session_id FROM session_traces").fetchall()
+            existing_ids = {row[0] for row in existing_sessions}
+            print(f"🔍 Database already contains {len(existing_ids)} sessions")
+
+            # Check for conflicts
+            conflicts = set(_SESSIONS.keys()) & existing_ids
+            if conflicts:
+                print(f"⚠️ Found {len(conflicts)} conflicting session IDs in database already!")
+                for conflict_id in list(conflicts)[:5]:  # Show first 5
+                    print(f" - {conflict_id}")
+                    # Get details about the existing session
+                    existing = db.conn.execute(
+                        "SELECT session_id, experiment_id, num_timesteps, created_at FROM session_traces WHERE session_id = ?",
+                        [conflict_id]
+                    ).fetchone()
+                    if existing:
+                        print(f" Existing: session_id={existing[0]}, experiment_id={existing[1]}, timesteps={existing[2]}, created={existing[3]}")
+
+        # Also check for duplicates within our own collection
+        session_ids = list(_SESSIONS.keys())
+        unique_ids = set(session_ids)
+        if len(session_ids) != len(unique_ids):
+            duplicates = len(session_ids) - len(unique_ids)
+            print(f"🚨 FOUND {duplicates} DUPLICATE SESSION IDs IN OUR COLLECTION!")
+            from collections import Counter
+            id_counts = Counter(session_ids)
+            for session_id, count in id_counts.items():
+                if count > 1:
+                    print(f" - {session_id}: {count} times")
+                    # Show the different experiment_ids for this session_id
+                    experiment_ids_for_session = [exp_id for exp_id, _ in _SESSIONS.values() if exp_id == session_id]
+                    print(f" Experiment IDs: {experiment_ids_for_session}")
+
+    if _SESSIONS:
+        with DuckDBTraceManager(config.database_path) as db:
+            uploaded_count = 0
+            skipped_count = 0
+
+            # Process each session individually to get better error reporting
+            for session_id, (experiment_id, trace) in _SESSIONS.items():
+                try:
+                    # Check if session already exists in database
+                    existing = db.conn.execute(
+                        "SELECT session_id, experiment_id, num_timesteps FROM session_traces WHERE session_id = ?",
+                        [session_id]
+                    ).fetchone()
+
+                    if existing:
+                        print(f"🔍 SESSION ALREADY EXISTS: {session_id}")
+                        print(f" Existing: session_id={existing[0]}, experiment_id={existing[1]}, timesteps={existing[2]}")
+                        print(f" New: experiment_id={experiment_id}, timesteps={trace.num_timesteps if hasattr(trace, 'num_timesteps') else 'unknown'}")
+                        print(f" Trace object type: {type(trace)}")
+                        print(f" Trace object keys: {trace.__dict__.keys() if hasattr(trace, '__dict__') else 'no __dict__'}")
+                        skipped_count += 1
+                        continue
+
+                    # Insert session (ON CONFLICT DO NOTHING handles duplicates)
+                    db.insert_session_trace(trace)
+                    # Update experiment_id
+                    db.conn.execute(
+                        "UPDATE session_traces SET experiment_id = ? "
+                        "WHERE session_id = ? AND (experiment_id IS NULL OR experiment_id = '')",
+                        [experiment_id, session_id]
+                    )
+                    uploaded_count += 1
+                except AssertionError as e:
+                    # Re-raise assertions to debug the issue
+                    print(f"🚨 ASSERTION ERROR for {session_id}: {e}")
+                    print(f" Trace object: {trace}")
+                    print(f" Trace type: {type(trace)}")
+                    if hasattr(trace, '__dict__'):
+                        print(f" Trace attributes: {trace.__dict__}")
+                    raise
+                except Exception as e:
+                    print(f"⚠️ Skipped {session_id}: {e}")
+                    print(f" Error type: {type(e).__name__}")
+                    print(f" Trace object type: {type(trace)}")
+                    if hasattr(trace, '__dict__'):
+                        print(f" Trace keys: {list(trace.__dict__.keys())}")
+                    skipped_count += 1
+
+        if not config.quiet:
+            print(f"✅ Uploaded {uploaded_count}/{len(_SESSIONS)} sessions to database")
+            if skipped_count > 0:
+                print(f"⚠️ Skipped {skipped_count} sessions due to errors")
+    else:
+        if not config.quiet:
+            print("⚠️ No sessions to upload")
+
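The upload loop above goes through `DuckDBTraceManager`, which wraps the package's actual schema and insert logic; the skip-if-present behaviour itself is just a parameterized SELECT followed by a conditional insert. A rough standalone illustration of that check-then-insert pattern against a plain DuckDB connection (the two-column `session_traces` table here is a simplified assumption, not the package's real schema):

```python
import duckdb

def upsert_session(conn: duckdb.DuckDBPyConnection, session_id: str, experiment_id: str) -> bool:
    """Insert a session row unless it already exists; return True if inserted."""
    existing = conn.execute(
        "SELECT session_id FROM session_traces WHERE session_id = ?", [session_id]
    ).fetchone()
    if existing is not None:
        return False  # already present: skip, mirroring the script's behaviour
    conn.execute(
        "INSERT INTO session_traces (session_id, experiment_id) VALUES (?, ?)",
        [session_id, experiment_id],
    )
    return True

conn = duckdb.connect()  # in-memory database, just for the sketch
conn.execute("CREATE TABLE session_traces (session_id TEXT PRIMARY KEY, experiment_id TEXT)")
print(upsert_session(conn, "sess-1", "exp-1"))  # True: inserted
print(upsert_session(conn, "sess-1", "exp-2"))  # False: skipped as a duplicate
```

Checking first and then inserting keeps the error reporting per session, which is why the script prefers it over a single bulk INSERT.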
+    # Print comprehensive results
+    if not config.quiet:
+        print_results_summary(all_results, experiment_ids)
+
+    # Generate comprehensive DuckDB analytics
+    print_comprehensive_model_analytics(config.database_path, experiment_ids, config.quiet)
+
+    # Save results to file
+    results_file = Path(config.v2_trace_dir) / f"multi_model_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+    with open(results_file, "w") as f:
+        json.dump({
+            "experiment_ids": experiment_ids,
+            "all_results": all_results,
+            "config": {
+                "models": args.models,
+                "episodes": config.num_episodes,
+                "max_turns": config.max_turns,
+                "timestamp": datetime.now().isoformat()
+            },
+            "analysis": {
+                "performance": analyze_performance_metrics(all_results),
+                "invalid_actions": analyze_invalid_actions(all_results),
+                "achievements": analyze_achievements_by_step(all_results)
+            }
+        }, f, indent=2)
+
+    if not config.quiet:
+        print(f"\n💾 Results saved to {results_file}")
+        print(f"📊 Database available at {config.database_path}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
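The JSON written at the end of `main()` nests everything under four top-level keys (`experiment_ids`, `all_results`, `config`, `analysis`). A small, hypothetical reader for such a file, assuming only that layout and making no assumptions about the per-episode dictionaries; the `./traces` directory is an illustrative placeholder, not the package's configured trace path:

```python
import json
from pathlib import Path

def summarize_results(path: Path) -> None:
    # Key names mirror the json.dump call above; episode dict contents are not assumed.
    data = json.loads(path.read_text())
    print("Run at:", data["config"]["timestamp"])
    for model, episodes in data["all_results"].items():
        exp_id = data["experiment_ids"].get(model, "unknown")
        print(f"{model}: {len(episodes)} episode records (experiment {exp_id})")

# Example: summarize the most recent results file in a trace directory.
latest = max(Path("./traces").glob("multi_model_results_*.json"), default=None)
if latest is not None:
    summarize_results(latest)
```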