synth-ai 0.2.2.dev0__py3-none-any.whl → 0.2.4.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/cli/__init__.py +66 -0
- synth_ai/cli/balance.py +205 -0
- synth_ai/cli/calc.py +70 -0
- synth_ai/cli/demo.py +74 -0
- synth_ai/{cli.py → cli/legacy_root_backup.py} +60 -15
- synth_ai/cli/man.py +103 -0
- synth_ai/cli/recent.py +126 -0
- synth_ai/cli/root.py +184 -0
- synth_ai/cli/status.py +126 -0
- synth_ai/cli/traces.py +136 -0
- synth_ai/cli/watch.py +508 -0
- synth_ai/config/base_url.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +252 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_duckdb_v2_backup.py +413 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +760 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_synth.py +34 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth.py +1740 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth_v2_backup.py +1318 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_duckdb_v2_backup.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v2_backup.py +1352 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +4 -4
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/test_crafter_react_agent_openai_v2_backup.py +2551 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +1 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/old/traces/session_crafter_episode_16_15227b68-2906-416f-acc4-d6a9b4fa5828_20250725_001154.json +1363 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +3 -3
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/crafter_custom/environment.py +1 -1
- synth_ai/environments/examples/enron/dataset/corbt___enron_emails_sample_questions/default/0.0.0/293c9fe8170037e01cc9cf5834e0cd5ef6f1a6bb/dataset_info.json +1 -0
- synth_ai/environments/examples/nethack/helpers/achievements.json +64 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +1 -1
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +5 -5
- synth_ai/environments/examples/red/units/test_movement_debug.py +2 -2
- synth_ai/environments/examples/red/units/test_retry_movement.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/available_envs.json +122 -0
- synth_ai/environments/examples/sokoban/verified_puzzles.json +54987 -0
- synth_ai/environments/service/core_routes.py +1 -1
- synth_ai/experimental/synth_oss.py +446 -0
- synth_ai/learning/core.py +21 -0
- synth_ai/learning/gateway.py +4 -0
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/mipro.py +8 -0
- synth_ai/lm/__init__.py +3 -0
- synth_ai/lm/core/main.py +4 -0
- synth_ai/lm/core/main_v3.py +238 -122
- synth_ai/lm/core/vendor_clients.py +4 -0
- synth_ai/lm/provider_support/openai.py +11 -2
- synth_ai/lm/vendors/base.py +7 -0
- synth_ai/lm/vendors/openai_standard.py +339 -4
- synth_ai/lm/vendors/openai_standard_responses.py +243 -0
- synth_ai/lm/vendors/synth_client.py +155 -5
- synth_ai/lm/warmup.py +54 -17
- synth_ai/tracing/__init__.py +18 -0
- synth_ai/tracing_v1/__init__.py +29 -14
- synth_ai/tracing_v3/__init__.py +2 -2
- synth_ai/tracing_v3/abstractions.py +62 -17
- synth_ai/tracing_v3/config.py +13 -7
- synth_ai/tracing_v3/db_config.py +6 -6
- synth_ai/tracing_v3/hooks.py +1 -1
- synth_ai/tracing_v3/llm_call_record_helpers.py +350 -0
- synth_ai/tracing_v3/lm_call_record_abstractions.py +257 -0
- synth_ai/tracing_v3/session_tracer.py +5 -5
- synth_ai/tracing_v3/tests/test_concurrent_operations.py +1 -1
- synth_ai/tracing_v3/tests/test_llm_call_records.py +672 -0
- synth_ai/tracing_v3/tests/test_session_tracer.py +43 -9
- synth_ai/tracing_v3/tests/test_turso_manager.py +1 -1
- synth_ai/tracing_v3/turso/manager.py +18 -11
- synth_ai/tracing_v3/turso/models.py +1 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/dashboard.py +329 -0
- synth_ai/v0/tracing/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/base_client.py +3 -3
- synth_ai/{tracing → v0/tracing}/client_manager.py +1 -1
- synth_ai/{tracing → v0/tracing}/context.py +1 -1
- synth_ai/{tracing → v0/tracing}/decorators.py +11 -11
- synth_ai/v0/tracing/events/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/events/manage.py +4 -4
- synth_ai/{tracing → v0/tracing}/events/scope.py +6 -6
- synth_ai/{tracing → v0/tracing}/events/store.py +3 -3
- synth_ai/{tracing → v0/tracing}/immediate_client.py +6 -6
- synth_ai/{tracing → v0/tracing}/log_client_base.py +2 -2
- synth_ai/{tracing → v0/tracing}/retry_queue.py +3 -3
- synth_ai/{tracing → v0/tracing}/trackers.py +2 -2
- synth_ai/{tracing → v0/tracing}/upload.py +4 -4
- synth_ai/v0/tracing_v1/__init__.py +16 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/base_client.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/client_manager.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/context.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/decorators.py +11 -11
- synth_ai/v0/tracing_v1/events/__init__.py +0 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/manage.py +4 -4
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/scope.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/store.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/immediate_client.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/log_client_base.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/retry_queue.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/trackers.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/upload.py +4 -4
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/METADATA +100 -5
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/RECORD +115 -75
- /synth_ai/{tracing/events/__init__.py → compound/cais.py} +0 -0
- /synth_ai/{tracing_v1/events/__init__.py → environments/examples/crafter_classic/debug_translation.py} +0 -0
- /synth_ai/{tracing → v0/tracing}/abstractions.py +0 -0
- /synth_ai/{tracing → v0/tracing}/config.py +0 -0
- /synth_ai/{tracing → v0/tracing}/local.py +0 -0
- /synth_ai/{tracing → v0/tracing}/utils.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/abstractions.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/config.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/local.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/utils.py +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2551 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Test script to run ReAct agents against Crafter environment on synth service (port 8901)
|
4
|
+
Tests on multiple easy Crafter instances with enhanced debugging
|
5
|
+
"""
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import json
|
9
|
+
import uuid
|
10
|
+
import math
|
11
|
+
import argparse
|
12
|
+
import toml
|
13
|
+
import logging
|
14
|
+
import time
|
15
|
+
import functools
|
16
|
+
from datetime import datetime
|
17
|
+
from typing import Dict, Any, Optional, List, Set
|
18
|
+
from pydantic import BaseModel, Field
|
19
|
+
from httpx import AsyncClient
|
20
|
+
import httpx
|
21
|
+
import sys
|
22
|
+
import os
|
23
|
+
from pathlib import Path
|
24
|
+
from tqdm.asyncio import tqdm_asyncio
|
25
|
+
from tqdm import tqdm
|
26
|
+
import random
|
27
|
+
from collections import defaultdict
|
28
|
+
|
29
|
+
# Switch Langfuse off before anything imports it, so it can neither hang nor
# warn during the run. The placeholder keys exist only to silence the
# "missing key" warnings.
os.environ.update(
    {
        "LANGFUSE_ENABLED": "false",
        "LANGFUSE_PUBLIC_KEY": "dummy",
        "LANGFUSE_SECRET_KEY": "dummy",
    }
)

# Quiet the Langfuse logger and drop any warning whose message mentions Langfuse.
import logging
import warnings

logging.getLogger("langfuse").setLevel(logging.ERROR)
warnings.filterwarnings("ignore", message=".*Langfuse.*")
|
40
|
+
|
41
|
+
from langfuse.openai import openai
|
42
|
+
from langfuse import Langfuse
|
43
|
+
|
44
|
+
# Langfuse subclass that is always constructed with tracing disabled.
class SilentLangfuse(Langfuse):
    """Langfuse client that is forced into the disabled state.

    Placeholder credentials ("dummy") are filled in for ``public_key`` and
    ``secret_key`` only when the caller did not pass them, and ``enabled`` is
    always overridden to ``False`` before delegating to ``Langfuse.__init__``.
    """

    def __init__(self, *args, **kwargs):
        for credential in ("public_key", "secret_key"):
            kwargs.setdefault(credential, "dummy")
        kwargs["enabled"] = False
        super().__init__(*args, **kwargs)


# Make later `langfuse.Langfuse` attribute lookups resolve to the silent client.
# NOTE(review): names bound before this line (e.g. `Langfuse` from the
# `from langfuse import Langfuse` above) still refer to the original class.
import langfuse

langfuse.Langfuse = SilentLangfuse
|
56
|
+
|
57
|
+
# --- Stop Langfuse background threads from blocking interpreter shutdown ---
try:
    import langfuse._task_manager.task_manager as _lftm

    # Replace the methods that join background threads at exit with no-ops.
    setattr(_lftm.TaskManager, "shutdown", lambda self: None)
    setattr(_lftm.TaskManager, "join", lambda self, *a, **k: None)

    import langfuse.prompt_cache as _lfpc

    setattr(_lfpc.PromptCacheTaskManager, "shutdown", lambda self: None)
    setattr(_lfpc.PromptCacheTaskManager, "join", lambda self, *a, **k: None)
except Exception:
    # Deliberately best-effort: if the internals moved or Langfuse is absent,
    # continue without the patch.
    pass
|
69
|
+
|
70
|
+
import numpy as np
|
71
|
+
|
72
|
+
# Import session tracer for CAIS event capture
|
73
|
+
from synth_ai.tracing_v2.session_tracer import (
|
74
|
+
SessionTracer, SessionEventMarkovBlanketMessage, TimeRecord,
|
75
|
+
RuntimeEvent, EnvironmentEvent
|
76
|
+
)
|
77
|
+
from synth_ai.tracing_v2.abstractions import CAISEvent
|
78
|
+
from synth_ai.tracing_v2.utils import create_experiment_context
|
79
|
+
from synth_ai.tracing_v2.duckdb.manager import DuckDBTraceManager
|
80
|
+
from datetime import datetime
|
81
|
+
|
82
|
+
# Retry policy consumed by retry_http_request().
MAX_RETRIES: int = 3  # total attempts per request (first try + 2 retries)
BASE_DELAY: float = 0.1  # seconds; first backoff delay, doubled on each retry
MAX_DELAY: float = 2.0  # seconds; cap on a single backoff delay
HTTP_TIMEOUT: float = 10.0  # seconds; per-request timeout passed to httpx
|
87
|
+
|
88
|
+
async def retry_http_request(client: AsyncClient, method: str, url: str, **kwargs) -> Any:
    """
    Retry HTTP requests with exponential backoff and jitter.

    A response with status code < 500 (success or 4xx client error) is
    returned immediately. 5xx responses and exceptions raised while sending
    are retried until ``MAX_RETRIES`` attempts have been made (at least one
    attempt is always made).

    Args:
        client: httpx AsyncClient
        method: HTTP method ('GET', 'POST', etc.)
        url: Request URL
        **kwargs: Additional arguments for the request. ``timeout`` defaults
            to ``HTTP_TIMEOUT`` but may be overridden by the caller.

    Returns:
        Response object

    Raises:
        Exception: The last failure once all attempts are exhausted. This is
            either the exception raised by the final attempt or an
            ``Exception`` describing the final 5xx response.
    """
    # A caller-supplied timeout used to collide with the hard-coded one
    # ("got multiple values for keyword argument 'timeout'"). That TypeError
    # was then retried as if it were a transient failure.
    kwargs.setdefault("timeout", HTTP_TIMEOUT)
    # Always make at least one attempt so `last_exception` is set before raising.
    attempts = max(MAX_RETRIES, 1)
    last_exception: Optional[BaseException] = None
    # Set when the ReadError handler already waited, so we don't sleep twice.
    backoff_done = False

    for attempt in range(attempts):
        if attempt > 0 and not backoff_done:
            delay = min(BASE_DELAY * (2 ** (attempt - 1)), MAX_DELAY)
            jitter = random.uniform(0, 0.1 * delay)  # up to 10% jitter
            # Intermediate retries are silent; only the final failure is printed.
            await asyncio.sleep(delay + jitter)
        backoff_done = False

        try:
            response = await client.request(method, url, **kwargs)

            if response.status_code < 500:  # success or client error (4xx): don't retry
                return response

            # Server error (5xx): remember it and try again.
            last_exception = Exception(f"HTTP {response.status_code}: {response.text}")

        except httpx.ReadError as e:
            # Connection-level read failures get a longer wait: 1s, 2s, 4s (max 5s).
            last_exception = e
            if attempt < attempts - 1:
                await asyncio.sleep(min(1.0 * (2 ** attempt), 5.0))
                backoff_done = True
        except Exception as e:
            last_exception = e
            # Don't log intermediate failures - only the final one.

    # All attempts failed
    print(f" ❌ HTTP request failed after {attempts} attempts: {type(last_exception).__name__}: {str(last_exception)[:200]}")
    raise last_exception
|
142
|
+
|
143
|
+
# Crafter achievement hooks are optional; fall back to an empty list without them.
try:
    from synth_ai.environments.examples.crafter_classic.trace_hooks import CRAFTER_HOOKS
except ImportError:
    print("Warning: Could not import CRAFTER_HOOKS")
    CRAFTER_HOOKS = []
else:
    print(f"✅ Loaded {len(CRAFTER_HOOKS)} Crafter achievement hooks (Easy, Medium, Hard)")
|
150
|
+
|
151
|
+
|
152
|
+
# Create a proper message structure with origin_system_id
|
153
|
+
def create_message(content: Any, message_type: str, origin_system_id: Any, turn: int) -> SessionEventMarkovBlanketMessage:
    """Wrap *content* in a Markov-blanket message tagged with its sender.

    The original content is nested under ``"payload"`` next to the stringified
    ``origin_system_id``. *turn* becomes the record's ``message_time``, and the
    current local time (``datetime.now().isoformat()``) becomes ``event_time``.
    """
    envelope = {"origin_system_id": str(origin_system_id), "payload": content}
    stamp = TimeRecord(event_time=datetime.now().isoformat(), message_time=turn)
    return SessionEventMarkovBlanketMessage(
        content=envelope,
        message_type=message_type,
        time_record=stamp,
    )
|
166
|
+
|
167
|
+
|
168
|
+
def compress_observation_for_trace(obs: Dict[str, Any]) -> Dict[str, Any]:
    """Compress observation data for efficient trace storage.

    Returns a shallow copy of *obs* (the input dict is not modified) in which:

    * ``semantic_map``, if present, is replaced by ``semantic_map_text``. The
      text is produced by ``format_semantic_map_view(..., view_size=7)``.
    * Each heavy field (``observation_image``, ``world_material_map``, ``rgb``,
      ``image``) whose value is a list or ``np.ndarray`` is dropped. In its place
      the function stores ``<field>_shape``, ``<field>_size_kb`` and
      ``<field>_hash`` metadata.

    Values of heavy fields that are not lists or arrays are left untouched.
    """
    import zlib

    obs_compressed = obs.copy()

    # Convert semantic map to text
    if "semantic_map" in obs_compressed:
        obs_compressed["semantic_map_text"] = format_semantic_map_view(obs_compressed, view_size=7)
        del obs_compressed["semantic_map"]

    # Skip heavy fields entirely - just store shape/size/hash metadata
    heavy_fields = ("observation_image", "world_material_map", "rgb", "image")
    for field in heavy_fields:
        value = obs_compressed.get(field)
        if not isinstance(value, (list, np.ndarray)):
            continue
        # Metadata is computed on a uint8 copy of the data.
        arr = np.array(value, dtype=np.uint8)
        obs_compressed[f"{field}_shape"] = arr.shape
        obs_compressed[f"{field}_size_kb"] = arr.nbytes / 1024
        # CRC32 is reproducible across processes. The builtin hash() of bytes
        # is salted per interpreter run (PYTHONHASHSEED), so the same frame got
        # a different "tracking" hash in every trace.
        obs_compressed[f"{field}_hash"] = zlib.crc32(arr.tobytes()) % 1000000
        del obs_compressed[field]

    return obs_compressed
|
191
|
+
|
192
|
+
|
193
|
+
def print_hook_legend():
    """Explain the single-character hook codes used in the progress output."""
    legend_lines = [
        "\n📖 Hook Legend:",
        " E = Easy achievement (e.g., collect_wood, place_table)",
        " M = Medium achievement (e.g., make_wood_pickaxe, collect_coal)",
        " H = Hard achievement (e.g., make_iron_sword, defeat_zombie)",
        " X = Invalid action (action had no effect)",
        " # = Regular step",
        "",  # trailing blank line separates the legend from the progress bars
    ]
    for line in legend_lines:
        print(line)
|
202
|
+
|
203
|
+
|
204
|
+
# NOTE: These custom progress display functions are no longer used - replaced with tqdm
|
205
|
+
# def create_progress_bar(episode_num: int, steps: List[str], max_steps: int) -> str:
|
206
|
+
# """Create a progress bar string with hook codes."""
|
207
|
+
# # Pad with spaces if fewer steps than max
|
208
|
+
# padded_steps = steps + [' '] * (max_steps - len(steps))
|
209
|
+
# bar = ''.join(padded_steps[:max_steps])
|
210
|
+
# return f"Episode {episode_num:2d}: [{bar}] {len(steps)}/{max_steps}"
|
211
|
+
|
212
|
+
|
213
|
+
# def update_progress_display(episode_bars: Dict[int, List[str]], max_steps: int):
|
214
|
+
# """Update the progress display in place."""
|
215
|
+
# # Clear previous lines
|
216
|
+
# num_episodes = len(episode_bars)
|
217
|
+
# if num_episodes > 0:
|
218
|
+
# # Move cursor up to overwrite previous display
|
219
|
+
# print(f"\033[{num_episodes}A", end='')
|
220
|
+
#
|
221
|
+
# # Print all episode progress bars
|
222
|
+
# for episode_num in sorted(episode_bars.keys()):
|
223
|
+
# steps = episode_bars[episode_num]
|
224
|
+
# print(create_progress_bar(episode_num, steps, max_steps))
|
225
|
+
#
|
226
|
+
# # Ensure we don't leave cursor in wrong position
|
227
|
+
# sys.stdout.flush()
|
228
|
+
|
229
|
+
def clear_progress_display():
    """Print blank lines below the progress area so error messages don't overlap it."""
    # Four newlines in a single print: three blank lines plus the line terminator.
    print(end="\n" * 4)
|
232
|
+
|
233
|
+
|
234
|
+
def print_achievements_table(all_achievements: Dict[str, int], num_episodes: int):
    """Print a beautiful table of achievements across all episodes.

    Args:
        all_achievements: Achievement name -> number of times it was unlocked.
        num_episodes: Denominator for the percentage column. When it is not
            positive, percentages are shown as 0.0% instead of raising
            ZeroDivisionError.

    Rows are ordered by count (descending), with ties broken by name. Nothing is
    printed when *all_achievements* is empty.
    """
    if not all_achievements:
        return

    print("\n" + "=" * 80)
    print("🏆 ACHIEVEMENTS SUMMARY")
    print("=" * 80)
    print(f"{'Achievement':<30} {'Count':<10} {'Percentage':<15}")
    print("-" * 55)

    # Sort achievements by count (descending) then by name
    sorted_achievements = sorted(all_achievements.items(), key=lambda x: (-x[1], x[0]))

    for achievement, count in sorted_achievements:
        # Guard against num_episodes == 0 (e.g. every episode errored out).
        percentage = (count / num_episodes) * 100 if num_episodes > 0 else 0.0
        print(f"{achievement:<30} {count:<10} {percentage:>6.1f}%")

    print("-" * 55)
    print(f"{'TOTAL UNIQUE':<30} {len(all_achievements):<10}")
    print("=" * 80)
|
255
|
+
|
256
|
+
|
257
|
+
def print_invalid_actions_table(invalid_actions: Dict[str, int], total_actions: Optional[Dict[str, int]] = None):
    """Print a table of invalid actions by type with failure rates.

    Args:
        invalid_actions: Action name -> number of times the action was invalid.
        total_actions: Optional action name -> total number of times the action
            was taken. When omitted, or when an action is missing from it, that
            action's invalid count is used as its total (a 100% failure rate).

    Rows are ordered by invalid count (descending), with ties broken by name.
    Nothing is printed when *invalid_actions* is empty.
    """
    if not invalid_actions:
        return

    print("\n" + "=" * 90)
    print("❌ INVALID ACTIONS SUMMARY")
    print("=" * 90)
    print(f"{'Action Type':<20} {'Invalid/Total':<15} {'Failure %':<12} {'Description':<35}")
    print("-" * 90)

    # Sort by count (descending) then by name
    sorted_actions = sorted(invalid_actions.items(), key=lambda x: (-x[1], x[0]))

    action_descriptions = {
        'move_left': 'Movement blocked (wall/edge)',
        'move_right': 'Movement blocked (wall/edge)',
        'move_up': 'Movement blocked (wall/edge)',
        'move_down': 'Movement blocked (wall/edge)',
        'do': 'Nothing to collect/attack',
        'sleep': 'Energy already full or conditions not met',
        'place_stone': 'No stone or invalid location',
        'place_table': 'No wood or invalid location',
        'place_furnace': 'No stone or invalid location',
        'place_plant': 'No sapling or invalid location',
        'make_wood_pickaxe': 'Missing materials or no table',
        'make_stone_pickaxe': 'Missing materials or no table',
        'make_iron_pickaxe': 'Missing materials or no furnace',
        'make_wood_sword': 'Missing materials or no table',
        'make_stone_sword': 'Missing materials or no table',
        'make_iron_sword': 'Missing materials or no furnace'
    }

    total_invalid = 0
    for action, invalid_count in sorted_actions:
        total_count = total_actions.get(action, invalid_count) if total_actions else invalid_count
        failure_rate = (invalid_count / total_count * 100) if total_count > 0 else 0
        fraction = f"{invalid_count}/{total_count}"
        description = action_descriptions.get(action, 'Unknown reason')
        print(f"{action:<20} {fraction:<15} {failure_rate:>6.1f}% {description:<35}")
        total_invalid += invalid_count

    print("-" * 90)
    total_all_actions = sum(total_actions.values()) if total_actions else total_invalid
    total_failure_rate = (total_invalid / total_all_actions * 100) if total_all_actions > 0 else 0
    # Pad the whole "invalid/total" fraction, as the rows do. The old code padded
    # only the denominator, which pushed the percentage out of its column.
    total_fraction = f"{total_invalid}/{total_all_actions}"
    print(f"{'TOTAL':<20} {total_fraction:<15} {total_failure_rate:>6.1f}%")
    print("=" * 90)
|
304
|
+
|
305
|
+
|
306
|
+
def print_termination_breakdown(termination_reasons: List[str]):
    """Print episode termination breakdown.

    Identical reason strings are counted together, and one row is printed per
    distinct reason, most frequent first.

    Reasons of the form ``"<kind>:<detail>"`` are shown under ``<kind>``, with
    a (possibly truncated) detail in the description. This applies when
    ``<kind>`` is ``agent_terminate``, ``http_error``, ``exception`` or
    ``outer_exception``.

    Args:
        termination_reasons: One termination reason string per episode.
    """
    print("\n" + "=" * 80)
    print("🏁 EPISODE TERMINATION BREAKDOWN")
    print("=" * 80)

    if not termination_reasons:
        print("No termination data available.")
        return

    # Count termination reasons, then sort by count descending
    reason_counts: Dict[str, int] = {}
    for reason in termination_reasons:
        reason_counts[reason] = reason_counts.get(reason, 0) + 1
    sorted_reasons = sorted(reason_counts.items(), key=lambda x: x[1], reverse=True)

    total_episodes = len(termination_reasons)  # > 0 thanks to the early return

    print(f"{'Termination Reason':<40} {'Count':<10} {'Percentage':<12} {'Description'}")
    print("-" * 80)

    # Descriptions for reasons that carry no detail suffix
    descriptions = {
        "max_turns_reached": "Episode completed all turns",
        "death": "Agent died (health <= 0)",
        "environment_terminated": "Environment ended episode",
        "no_actions_provided": "Agent failed to provide actions"
    }

    for reason, count in sorted_reasons:
        percentage = count / total_episodes * 100

        display_reason = reason
        if reason in descriptions:
            description = descriptions[reason]
        elif reason.startswith("agent_terminate:"):
            display_reason = "agent_terminate"
            description = f"Agent chose to quit: {reason.split(':', 1)[1][:30]}"
        elif reason.startswith("http_error:"):
            display_reason = "http_error"
            description = f"API request failed: {reason.split(':', 1)[1]}"
        elif reason.startswith("exception:") or reason.startswith("outer_exception:"):
            display_reason, _, error_detail = reason.partition(":")
            error_detail = error_detail.strip()
            label = "Runtime error" if display_reason == "exception" else "Fatal error"
            description = f"{label}: {error_detail[:40]}" if error_detail else f"{label} (unknown cause)"
        else:
            description = "Other termination reason"

        # Format "%" together with the number and pad the pair as one unit, so
        # the text lines up under the 12-wide "Percentage" header. The old
        # `{percentage:<11.1f}%` put the padding before the "%" ("50.0       %").
        percentage_text = f"{percentage:.1f}%"
        print(f"{display_reason:<40} {count:<10} {percentage_text:<12} {description}")

    print("-" * 80)
    print(f"{'TOTAL':<40} {total_episodes:<10} {'100.0%':<12}")
    print("=" * 80)
|
378
|
+
|
379
|
+
|
380
|
+
def print_timing_analysis(results: List[Dict[str, Any]]):
    """Print comprehensive timing analysis.

    Reads ``result["timing"]`` from every result that has one and whose
    ``error`` flag is falsy. The expected keys are ``episode_total_time``,
    ``step_times``, ``env_times`` and ``agent_times``. From these it prints
    mean/median/min/max/percentile statistics for episodes, for steps, and
    for environment vs. agent calls.

    Percentiles use the element at index ``int(len * q)`` of the sorted values,
    so the median is the upper-middle element for even-length inputs. The
    index is clamped to the last element.
    """
    if not results:
        return

    def pct(sorted_values: List[float], q: float) -> float:
        # Clamp so q near 1.0 (or float rounding) can never index past the end.
        return sorted_values[min(int(len(sorted_values) * q), len(sorted_values) - 1)]

    # Extract timing data from valid results
    episode_times: List[float] = []
    all_step_times: List[float] = []
    all_env_times: List[float] = []
    all_agent_times: List[float] = []

    for result in results:
        if not result.get("error", False) and "timing" in result:
            timing = result["timing"]
            episode_times.append(timing["episode_total_time"])
            all_step_times.extend(timing["step_times"])
            all_env_times.extend(timing["env_times"])
            all_agent_times.extend(timing["agent_times"])

    if not episode_times:
        print("⚠️ No timing data available for analysis")
        return

    print("=" * 80)
    print("⏱️ TIMING ANALYSIS")
    print("=" * 80)

    # Episode-level timing
    print("📊 EPISODE TIMING DISTRIBUTION")
    print("-" * 40)
    episode_times.sort()
    print(f"Total Episodes: {len(episode_times)}")
    print(f"Mean Episode Time: {sum(episode_times)/len(episode_times):.2f}s")
    print(f"Median Episode Time: {pct(episode_times, 0.5):.2f}s")
    print(f"Min Episode Time: {min(episode_times):.2f}s")
    print(f"Max Episode Time: {max(episode_times):.2f}s")
    print(f"P95 Episode Time: {pct(episode_times, 0.95):.2f}s")
    print()

    # Step-level timing
    if all_step_times:
        print("📊 STEP TIMING DISTRIBUTION")
        print("-" * 40)
        all_step_times.sort()
        print(f"Total Steps: {len(all_step_times)}")
        print(f"Mean Step Time: {sum(all_step_times)/len(all_step_times):.2f}s")
        print(f"Median Step Time: {pct(all_step_times, 0.5):.2f}s")
        print(f"Min Step Time: {min(all_step_times):.2f}s")
        print(f"Max Step Time: {max(all_step_times):.2f}s")
        print(f"P95 Step Time: {pct(all_step_times, 0.95):.2f}s")
        print()

    # Environment vs Agent timing
    if all_env_times and all_agent_times:
        print("📊 ENVIRONMENT vs AGENT TIMING")
        print("-" * 40)

        total_env_time = sum(all_env_times)
        total_agent_time = sum(all_agent_times)
        env_mean = total_env_time / len(all_env_times)
        agent_mean = total_agent_time / len(all_agent_times)

        print(f"Environment Calls: {len(all_env_times)}")
        print(f"Agent Calls: {len(all_agent_times)}")
        print(f"Mean Environment Time: {env_mean:.2f}s")
        print(f"Mean Agent Time: {agent_mean:.2f}s")
        if agent_mean > 0:
            print(f"Environment/Agent Ratio: {env_mean/agent_mean:.2f}x")
        else:
            # Every agent call took zero time; the ratio is undefined
            # (this used to raise ZeroDivisionError).
            print("Environment/Agent Ratio: n/a")

        # Time breakdown
        total_time = total_env_time + total_agent_time
        if total_time > 0:
            print(f"Environment %: {(total_env_time/total_time)*100:.1f}%")
            print(f"Agent %: {(total_agent_time/total_time)*100:.1f}%")

        print()

        # Distribution comparison
        all_env_times.sort()
        all_agent_times.sort()
        for label, values in (("Environment", all_env_times), ("Agent", all_agent_times)):
            print(f"{label} Time Distribution:")
            print(f" P50: {pct(values, 0.5):.2f}s")
            print(f" P90: {pct(values, 0.9):.2f}s")
            print(f" P95: {pct(values, 0.95):.2f}s")
            print(f" P99: {pct(values, 0.99):.2f}s")

    print("=" * 80)
|
474
|
+
|
475
|
+
|
476
|
+
def print_condensed_summary(all_achievements: Dict[str, int], invalid_actions: Dict[str, int],
                            total_actions: Dict[str, int], termination_reasons: List[str],
                            results: List[Dict[str, Any]], num_episodes: int):
    """Print a dense, condensed summary of all metrics in a single compact table.

    Args:
        all_achievements: Achievement name -> number of times it was unlocked.
        invalid_actions: Action name -> number of invalid uses.
        total_actions: Action name -> total number of uses.
        termination_reasons: One termination reason string per episode.
        results: Per-episode result dicts. The keys ``timing_info`` (with
            optional ``episode_time`` / ``step_times``) and ``num_achievements``
            are read when present.
        num_episodes: Episode count, used only for display.
    """
    print("\n" + "─" * 80)
    print("CRAFTER EVALUATION SUMMARY")
    print("─" * 80)

    # Calculate aggregated metrics
    unique_achievements = len(all_achievements)
    total_achievements = sum(all_achievements.values())
    total_invalid = sum(invalid_actions.values())
    total_acts = sum(total_actions.values())
    invalid_rate = (total_invalid / total_acts * 100) if total_acts > 0 else 0

    # Timing metrics. Both keys are optional, so each is guarded separately.
    # A timing_info without 'episode_time' used to raise KeyError.
    episode_times: List[float] = []
    step_times: List[float] = []
    for r in results:
        timing_info = r.get('timing_info')
        if not timing_info:
            continue
        if 'episode_time' in timing_info:
            episode_times.append(timing_info['episode_time'])
        if 'step_times' in timing_info:
            step_times.extend(timing_info['step_times'])

    avg_episode_time = sum(episode_times) / len(episode_times) if episode_times else 0
    avg_step_time = sum(step_times) / len(step_times) if step_times else 0

    # Termination breakdown
    term_counts: Dict[str, int] = {}
    for reason in termination_reasons:
        term_counts[reason] = term_counts.get(reason, 0) + 1

    # Best achievement
    best_achievement = max(all_achievements.items(), key=lambda x: x[1])[0] if all_achievements else "none"

    # Achievement distribution by trajectory: achievements-per-episode -> episode count
    achievement_by_episode: Dict[Any, int] = {}
    for r in results:
        if 'num_achievements' in r:
            count = r['num_achievements']
            achievement_by_episode[count] = achievement_by_episode.get(count, 0) + 1
    achv_dist_str = " | ".join(f"{k} achv: {v}" for k, v in sorted(achievement_by_episode.items()))

    # Print compact table
    print(f"Episodes: {num_episodes} | Achievements: {unique_achievements} ({total_achievements} total) | Invalid: {total_invalid}/{total_acts} ({invalid_rate:.1f}%)")
    print(f"Avg Episode: {avg_episode_time:.1f}s | Avg Step: {avg_step_time:.1f}s | Best: {best_achievement}")

    # Most common termination
    if term_counts:
        most_common_term = max(term_counts.items(), key=lambda x: x[1])
        print(f"Termination: {most_common_term[0]} ({most_common_term[1]}/{num_episodes})")

    # Achievement distributions
    if achv_dist_str:
        print(f"Achv by traj: {achv_dist_str}")

    # Print achievement frequencies vertically (share of all unlocks)
    if all_achievements and total_achievements > 0:
        print("\nAchievement frequencies:")
        for achv, count in sorted(all_achievements.items(), key=lambda x: x[1], reverse=True):
            print(f" {achv:<20} {count/total_achievements*100:>3.0f}%")

    print("─" * 80)
|
542
|
+
|
543
|
+
|
544
|
+
def _trace_is_agent_event(event: Dict[str, Any]) -> bool:
    """Return True when *event* was emitted by the Crafter ReAct agent."""
    # `or ""` also covers an explicit None system_instance_id.
    return str(event.get("system_instance_id") or "").startswith("crafter-react-agent")


def _trace_message_time(record: Dict[str, Any], default: Any) -> Any:
    """Return ``record["time_record"]["message_time"]``, or *default* when absent."""
    return (record.get("time_record") or {}).get("message_time", default)


def _trace_payload(message: Dict[str, Any]) -> Dict[str, Any]:
    """Return the ``content.payload`` dict of an observation message, or ``{}``."""
    content = message.get("content")
    payload = content.get("payload") if isinstance(content, dict) else None
    return payload if isinstance(payload, dict) else {}


def _trace_first_response(cais_event: Dict[str, Any]) -> Dict[str, Any]:
    """Return the response dict of the event's first LLM call record, or ``{}``."""
    records = cais_event.get("llm_call_records") or []
    if records and isinstance(records[0], dict):
        response = records[0].get("response")
        if isinstance(response, dict):
            return response
    return {}


def _trace_parse_arguments(raw: Any) -> Optional[Dict[str, Any]]:
    """Decode tool-call arguments; return None when they are not a JSON object."""
    if isinstance(raw, dict):
        return raw
    try:
        parsed = json.loads(raw or "{}")
    except (TypeError, ValueError):
        # LLM-produced arguments are not guaranteed to be valid JSON.
        return None
    return parsed if isinstance(parsed, dict) else None


def _trace_print_hooks(header: str, metadata: Any) -> None:
    """Print hook metadata entries under *header*; does nothing when empty."""
    if not metadata:
        return
    print(header)
    for meta in metadata:
        if isinstance(meta, dict):
            print(f" - {meta.get('hook_name', 'unknown')}: {meta.get('description', '')}")
        else:
            print(f" - {meta}")


def _trace_print_tokens(cais_event: Dict[str, Any]) -> None:
    """Print token usage from top-level event keys, falling back to the LLM record."""
    tokens = {key: cais_event.get(f"{key}_tokens") for key in ("prompt", "completion", "total")}
    if any(t is not None for t in tokens.values()):
        print(f"🪙 Tokens: Prompt={tokens['prompt']}, Completion={tokens['completion']}, Total={tokens['total']}")
        return
    usage = _trace_first_response(cais_event).get("usage") or {}
    if usage:
        print(f"🪙 Tokens: Prompt={usage.get('prompt_tokens')}, Completion={usage.get('completion_tokens')}, Total={usage.get('total_tokens')}")


def _trace_print_tool_calls(cais_event: Dict[str, Any]) -> None:
    """Print each tool call (name, actions, truncated reasoning) of the first LLM response."""
    choices = _trace_first_response(cais_event).get("choices") or []
    if not choices or not isinstance(choices[0], dict):
        return
    tool_calls = (choices[0].get("message") or {}).get("tool_calls") or []
    for tc in tool_calls:
        function = (tc.get("function") if isinstance(tc, dict) else None) or {}
        print(f"🔧 Tool: {function.get('name', 'unknown')}")
        args = _trace_parse_arguments(function.get("arguments", "{}"))
        if args is None:
            print(" Arguments: <unparseable>")
            continue
        actions = args.get("actions") or []
        if isinstance(actions, str):
            actions = [actions]
        reasoning = args.get("reasoning") or ""
        if actions:
            print(f" Actions: {', '.join(map(str, actions))}")
        if reasoning:
            print(f" Reasoning: {str(reasoning)[:60]}...")


def _trace_print_state_changes(before_obs: Dict[str, Any], after_obs: Dict[str, Any]) -> None:
    """Print new achievements, inventory deltas and position change between two observations."""
    before_achievements = before_obs.get("achievements_status") or {}
    after_achievements = after_obs.get("achievements_status") or {}
    new_achievements = [
        name for name, done in after_achievements.items()
        if done and not before_achievements.get(name, False)
    ]
    if new_achievements:
        print(f"🏆 New Achievements: {', '.join(new_achievements)}")

    before_inv = before_obs.get("inventory") or {}
    after_inv = after_obs.get("inventory") or {}
    inv_changes = []
    for item, count in after_inv.items():
        before_count = before_inv.get(item, 0)
        delta = count - before_count
        if delta > 0:
            inv_changes.append(f"{item}: {before_count} → {count} (+{delta})")
        elif delta < 0:
            inv_changes.append(f"{item}: {before_count} → {count} ({delta})")
    if inv_changes:
        print(f"📦 Inventory Changes: {', '.join(inv_changes)}")

    before_pos = before_obs.get("player_position", [0, 0])
    after_pos = after_obs.get("player_position", [0, 0])
    if before_pos != after_pos:
        print(f"📍 Position: {before_pos} → {after_pos}")


def analyze_trace_file(trace_file: Path):
    """Analyze a trace file and print detailed step-by-step information.

    The file is read as a JSON object with optional ``event_history`` and
    ``message_history`` lists. Events whose ``system_instance_id`` starts with
    ``"crafter-react-agent"`` are treated as agent steps; other events that
    carry a ``reward`` key are treated as environment events. For each agent
    step this prints token usage, tool calls, fired hooks and the
    achievement/inventory/position changes between the observations at the
    step's message time and the following one.

    Args:
        trace_file: path to the JSON trace (``str`` is accepted as well).

    Output goes to stdout and nothing is returned. If the file cannot be
    loaded, or its top level is not a JSON object, an error is printed and
    the function returns early.
    """
    import time

    trace_file = Path(trace_file)
    start_time = time.time()
    print(f"\n📄 Analyzing trace: {trace_file.name}")
    print("=" * 80)

    try:
        with open(trace_file, "r", encoding="utf-8") as f:
            trace_data = json.load(f)
    except Exception as e:
        print(f"[ERROR] Failed to load trace file: {e}")
        return
    if not isinstance(trace_data, dict):
        print("[ERROR] Failed to load trace file: expected a JSON object at the top level")
        return

    events = [e for e in (trace_data.get("event_history") or []) if isinstance(e, dict)]
    messages = [m for m in (trace_data.get("message_history") or []) if isinstance(m, dict)]

    # Partition once instead of re-filtering the full histories on every turn.
    cais_events = [e for e in events if _trace_is_agent_event(e)]
    env_candidates = [e for e in events if not _trace_is_agent_event(e) and "reward" in e]
    observations = [m for m in messages if m.get("message_type") == "observation"]

    for i, cais_event in enumerate(cais_events):
        print(f"\n🎮 Step {i + 1}")
        print("-" * 40)

        _trace_print_tokens(cais_event)
        _trace_print_tool_calls(cais_event)
        _trace_print_hooks("🎯 Agent Hooks Fired:", cais_event.get("event_metadata"))

        turn_time = _trace_message_time(cais_event, i)
        # Observations at this turn and the next one give the before/after state.
        if isinstance(turn_time, (int, float)):
            wanted_times = (turn_time, turn_time + 1)
        else:
            wanted_times = (turn_time,)
        turn_observations = [m for m in observations if _trace_message_time(m, -1) in wanted_times]
        if len(turn_observations) >= 2:
            _trace_print_state_changes(
                _trace_payload(turn_observations[0]),
                _trace_payload(turn_observations[1]),
            )

        env_events = [e for e in env_candidates if _trace_message_time(e, -1) == turn_time]
        if env_events:
            _trace_print_hooks("🌍 Environment Hooks Fired:", env_events[0].get("event_metadata"))

    print("\n" + "=" * 80)

    end_time = time.time()
    print(f"[DEBUG] Trace analysis completed in {end_time - start_time:.2f} seconds")
|
676
|
+
|
677
|
+
|
678
|
+
# --- Configuration Class ---
|
679
|
+
class CrafterConfig:
    """Configuration for Crafter evaluation.

    Values start from the defaults set in ``__init__``. They are optionally
    overridden from a TOML file with the sections ``[eval]``, ``[service]``,
    ``[output]`` and ``[openai]``. If a custom OpenAI-compatible endpoint is
    configured, it is exported through the ``OPENAI_BASE_URL`` and
    ``OPENAI_API_KEY`` environment variables.
    """

    def __init__(self, config_path: Optional[str] = None):
        # Default values
        self.model_name: Optional[str] = None  # Must be provided via config or CLI
        self.num_instances = 1
        self.max_turns = 2
        self.difficulty = "easy"
        self.service_base_url = "http://localhost:8901"
        self.service_timeout = 30.0
        self.seed = 42
        self.save_traces = True
        self.save_detailed_results = True
        self.verbose = False
        self.analyze_traces = False  # Whether to analyze traces after evaluation

        # Custom OpenAI endpoint support
        self.custom_openai_base_url = None  # e.g., "https://lora-inference-service-xyz.modal.run"
        self.custom_openai_api_key = "dummy"  # Default dummy key for custom endpoints

        # Load from TOML if provided; an explicit but missing path is reported
        # instead of being silently ignored.
        if config_path and os.path.exists(config_path):
            self.load_from_toml(config_path)
        elif config_path:
            print(f"⚠️ Warning: config file not found: {config_path} (using defaults)")

        # NOTE: model_name may still be None here; it is not validated in this class.
        # Configure custom OpenAI endpoint if specified
        self._configure_custom_openai()

    def load_from_toml(self, config_path: str):
        """Load configuration from a TOML file, keeping current values for missing keys."""
        config = toml.load(config_path)

        # Extract eval settings
        eval_config = config.get("eval", {})
        self.model_name = eval_config.get("model_name", self.model_name)
        self.num_instances = eval_config.get("episodes", self.num_instances)
        self.max_turns = eval_config.get("max_steps", self.max_turns)
        self.difficulty = eval_config.get("difficulty", self.difficulty)
        self.seed = eval_config.get("seed", self.seed)

        # Extract service settings
        service_config = config.get("service", {})
        self.service_base_url = service_config.get("base_url", self.service_base_url)
        self.service_timeout = service_config.get("timeout", self.service_timeout)

        # Extract output settings
        output_config = config.get("output", {})
        self.save_traces = output_config.get("save_traces", self.save_traces)
        self.save_detailed_results = output_config.get(
            "save_detailed_results", self.save_detailed_results
        )

        # Extract custom OpenAI endpoint settings
        openai_config = config.get("openai", {})
        self.custom_openai_base_url = openai_config.get("base_url", self.custom_openai_base_url)
        self.custom_openai_api_key = openai_config.get("api_key", self.custom_openai_api_key)

    @staticmethod
    def _mask_secret(secret) -> str:
        """Return a redacted form of *secret* suitable for logs.

        Secrets longer than 8 characters show only their last 4 characters;
        shorter ones are fully hidden.
        """
        text = "" if secret is None else str(secret)
        if len(text) <= 8:
            return "****"
        return "****" + text[-4:]

    def _configure_custom_openai(self):
        """Configure environment variables for a custom OpenAI endpoint if one is set."""
        if self.custom_openai_base_url:
            # Ensure the base URL ends with /v1 for OpenAI compatibility
            base_url = self.custom_openai_base_url.rstrip("/")
            if not base_url.endswith("/v1"):
                base_url += "/v1"

            # Set environment variables for OpenAI SDK
            os.environ["OPENAI_BASE_URL"] = base_url
            os.environ["OPENAI_API_KEY"] = self.custom_openai_api_key

            print(f"🔧 Configured custom OpenAI endpoint: {base_url}")
            # Never echo the full key to stdout/logs.
            print(f" API Key: {self._mask_secret(self.custom_openai_api_key)}")

            # Auto-detect if this looks like a fine-tuned model and add ft: regex support
            if self.model_name and (
                self.model_name.startswith("ft:") or "lora" in self.model_name.lower()
            ):
                self._add_ft_regex_support()

    def _add_ft_regex_support(self):
        """Add an ``^ft:.*$`` pattern to OpenAI naming regexes if not already present."""
        try:
            import re
            from synth_ai.lm.core import vendor_clients

            # Check if ft: pattern already exists
            ft_pattern = re.compile(r"^ft:.*$")
            if not any(
                pattern.pattern == ft_pattern.pattern
                for pattern in vendor_clients.openai_naming_regexes
            ):
                # Add ft: pattern at the beginning to catch all fine-tuned models
                vendor_clients.openai_naming_regexes.insert(0, ft_pattern)
                print("✅ Added ft:* regex pattern for fine-tuned model support")
        except Exception as e:
            print(f"⚠️ Warning: Could not add ft: regex pattern: {e}")

    def set_custom_endpoint(self, base_url: str, api_key: str = "dummy"):
        """Programmatically set a custom OpenAI endpoint and export it to the environment."""
        self.custom_openai_base_url = base_url
        self.custom_openai_api_key = api_key
        self._configure_custom_openai()
|
781
|
+
|
782
|
+
|
783
|
+
# --- Global Config ---
|
784
|
+
# Module-wide default configuration built without a TOML file, so model_name
# starts as None and must be filled in before an evaluation uses it.
config = CrafterConfig()
|
785
|
+
|
786
|
+
|
787
|
+
# --- Helper to build crafter semantic mapping ---
|
788
|
+
@functools.lru_cache(maxsize=1)
def get_crafter_semantic_mapping():
    """Return ``(id_to_item, player_idx)`` describing Crafter's semantic map ids.

    ``id_to_item[i]`` is the lower-cased name of the material or object whose
    semantic id is ``i`` (``"void"`` for ids that nothing maps to), and
    ``player_idx`` is the id of ``"player"``. The ids are read from a
    throwaway ``crafter.Env`` that is closed afterwards, and the result is
    cached after the first call. Returns ``(None, None)`` when crafter cannot
    be imported.
    """
    try:
        import crafter
        import itertools

        env = None
        try:
            env = crafter.Env()
            material_ids = env._world._mat_ids
            object_ids = env._sem_view._obj_ids

            table_size = max(max(material_ids.values()), max(object_ids.values())) + 1
            table = ["void"] * table_size

            for key, sem_id in itertools.chain(material_ids.items(), object_ids.items()):
                # Keys may be None, classes (use their __name__) or plain strings.
                if key is None:
                    label = "none"
                elif hasattr(key, "__name__"):
                    label = key.__name__
                else:
                    label = str(key)
                table[sem_id] = label.lower()

            return table, table.index("player")
        finally:
            if env:
                try:
                    env.close()
                except Exception:
                    pass
                del env
    except ImportError:
        # crafter is unavailable: callers treat (None, None) as "no mapping".
        return None, None
|
829
|
+
|
830
|
+
|
831
|
+
def format_semantic_map_view(obs_data: Dict[str, Any], view_size: int = 7) -> str:
    """Format an ASCII view of the semantic map centred on the player.

    Args:
        obs_data: observation dict; reads ``"semantic_map"`` (2-D, or a flattened
            square) and ``"player_position"`` (defaults to ``[0, 0]``).
        view_size: requested window width. The window is always odd-sized
            (``2 * (view_size // 2) + 1``) so the player sits at its centre;
            values below 1 are treated as 1.

    Returns:
        The rendered window followed by a line listing visible item names.
        Cells outside the map are ``"void"``, and ids missing from the mapping
        are ``"?"``. On any failure a one-line ``"Map view ..."`` message is
        returned instead of raising.
    """
    try:
        # Get mapping list
        id_to_item, _ = get_crafter_semantic_mapping()
        if id_to_item is None:
            return "Map view unavailable (crafter not installed)"

        semantic_map = obs_data.get("semantic_map")
        player_position = obs_data.get("player_position", [0, 0])

        if semantic_map is None:
            return "Map view unavailable (no semantic map data)"

        # Ensure numpy array with 2 dimensions
        sem_arr = np.asarray(semantic_map)
        if sem_arr.ndim == 1:
            # Probably flattened; only a perfect square can be reshaped unambiguously.
            size = int(round(float(np.sqrt(sem_arr.size))))
            if size * size != sem_arr.size:
                return "Map view unavailable (flattened map is not square)"
            sem_arr = sem_arr.reshape(size, size)
        elif sem_arr.ndim != 2:
            return "Map view unavailable (invalid map dimensionality)"

        px, py = map(int, player_position)
        half = max(view_size, 1) // 2
        span = 2 * half + 1  # actual number of rows/columns rendered
        n_items = len(id_to_item)

        rows = []
        visible = set()
        for dy in range(-half, half + 1):
            row_tokens = []
            for dx in range(-half, half + 1):
                x, y = px + dx, py + dy
                if 0 <= x < sem_arr.shape[0] and 0 <= y < sem_arr.shape[1]:
                    if dx == 0 and dy == 0:
                        token = "player"
                    else:
                        idx = int(sem_arr[x, y])
                        # Reject negative ids too: they would index from the end of the list.
                        token = id_to_item[idx] if 0 <= idx < n_items else "?"
                else:
                    token = "void"
                row_tokens.append(token)
                if token not in {"void", "player"}:
                    visible.add(token)
            rows.append(" ".join(row_tokens))

        map_view = f"\nLocal Map View ({span}x{span}):\n" + "\n".join(rows)
        if visible:
            map_view += "\nVisible items: " + ", ".join(sorted(visible))
        else:
            map_view += "\nNo special items visible (mostly grass/empty)"
        return map_view
    except Exception as e:
        return f"Map view error: {e}"
|
884
|
+
|
885
|
+
|
886
|
+
# --- Shaped Reward Configuration ---
|
887
|
+
# K-values for shaped reward calculation: calculate_shaped_reward() adds
# K * log(count + 1) for every achievement unlocked at least once, so a larger
# K gives that achievement more weight in the shaped reward.
ACHIEVEMENT_K_VALUES = {
    "collect_coal": 3.0,
    "collect_diamond": 100.0,
    "collect_drink": 0.1,
    "collect_iron": 10.0,
    "collect_sapling": 0.1,
    "collect_stone": 1.0,
    "collect_wood": 1.0,
    "defeat_skeleton": 1.0,
    "defeat_zombie": 1.0,
    "eat_cow": 1.0,
    "eat_plant": 0.1,
    "make_iron_pickaxe": 30.0,
    "make_iron_sword": 30.0,
    "make_stone_pickaxe": 10.0,
    "make_stone_sword": 10.0,
    "make_wood_pickaxe": 3.0,
    "make_wood_sword": 3.0,
    "place_furnace": 10.0,
    "place_plant": 0.1,
    "place_stone": 1.0,
    "place_table": 3.0,
    "wake_up": 0.1,
}
|
912
|
+
|
913
|
+
|
914
|
+
# --- Tool Definitions ---
|
915
|
+
def get_openai_tools():
    """Get OpenAI-compatible tool definitions.

    Returns a fresh list with two function tools:
      * ``interact`` - requires ``actions`` (list of action-name strings) and
        ``reasoning`` (string);
      * ``terminate`` - requires ``reason`` (string).

    The example in the ``actions`` description uses only valid Crafter action
    names, so the model is not steered toward actions that do not exist.
    """
    interact_tool = {
        "type": "function",
        "function": {
            "name": "interact",
            "description": "Perform 1-5 actions in sequence in the Crafter environment.",
            "parameters": {
                "type": "object",
                "properties": {
                    "actions": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of 1-5 action names to execute in sequence (e.g., ['move_up', 'do', 'move_down'])",
                    },
                    "reasoning": {
                        "type": "string",
                        "description": "Brief explanation of why these actions were chosen",
                    },
                },
                "required": ["actions", "reasoning"],
            },
        },
    }
    terminate_tool = {
        "type": "function",
        "function": {
            "name": "terminate",
            "description": "End the episode when finished or no progress can be made.",
            "parameters": {
                "type": "object",
                "properties": {
                    "reason": {
                        "type": "string",
                        "description": "Reason for termination",
                    }
                },
                "required": ["reason"],
            },
        },
    }
    return [interact_tool, terminate_tool]
|
958
|
+
|
959
|
+
|
960
|
+
# --- Shaped Reward Helper ---
|
961
|
+
def calculate_shaped_reward(
    achievement_counts: Dict[str, int],
    k_values: Optional[Dict[str, float]] = None,
) -> Dict[str, Any]:
    """Calculate a shaped reward of ``K * log(count + 1)`` summed over achievements.

    Args:
        achievement_counts: achievement name -> number of times it was unlocked.
        k_values: achievement name -> weight K. Defaults to
            ``ACHIEVEMENT_K_VALUES``. Achievements without a weight, or with a
            count <= 0, contribute nothing.

    Returns:
        ``{"total_shaped_reward": float, "breakdown": {name: {"count",
        "k_value", "contribution"}}}``, with one breakdown entry per
        contributing achievement.
    """
    weights = ACHIEVEMENT_K_VALUES if k_values is None else k_values
    total_reward = 0.0
    reward_breakdown = {}

    for achievement, count in achievement_counts.items():
        if count > 0 and achievement in weights:
            k_value = weights[achievement]
            # log(count + 1) grows sub-linearly with repeats, so the first unlock
            # is worth K * log(2) and later repeats add progressively less.
            reward_contribution = k_value * math.log(count + 1)
            total_reward += reward_contribution
            reward_breakdown[achievement] = {
                "count": count,
                "k_value": k_value,
                "contribution": reward_contribution,
            }

    return {"total_shaped_reward": total_reward, "breakdown": reward_breakdown}
|
979
|
+
|
980
|
+
|
981
|
+
# --- Base ReAct Agent ---
|
982
|
+
class BaseReActAgent:
    """Base ReAct agent for environment interaction.

    Each :meth:`decide` call does the following:
      * sends a fresh system + user prompt to the OpenAI chat-completions API,
        offering the tools from ``get_openai_tools()``;
      * logs the call as a Langfuse generation;
      * when a tracer with an active session is attached, records the call as
        a ``CAISEvent``.

    ``agent_state`` holds running counters plus a bounded message history.
    Abbreviated snapshots of it are stored in those events.
    """

    # Maximum entries kept in agent_state["message_history"] before trimming.
    _MAX_HISTORY_LENGTH = 20

    def __init__(self, model_name: str, max_turns: int = 20, verbose: bool = False, tracer: Optional["SessionTracer"] = None):
        import uuid

        self.model_name = model_name
        self.max_turns = max_turns
        self.verbose = verbose
        self.history = []
        self.system_name = "base-react-agent"
        self.tools = get_openai_tools()
        self.tracer = tracer
        # Unique system ID for this agent instance
        self.system_id = uuid.uuid4()

        # Agent state tracking
        self.agent_state = {
            "message_history": [],  # LLM conversation history
            "steps_taken": 0,
            "steps_remaining": max_turns,
            "total_tokens_used": 0,
            "tool_calls_made": 0,
            "current_turn": 0,
        }

    def decide(self, obs: str, system_message: str, turn: int) -> Dict[str, Any]:
        """Ask the LLM for the next action and return it as a decision dict.

        Args:
            obs: formatted observation text for this turn.
            system_message: system prompt sent together with the observation.
            turn: zero-based turn index.

        Returns:
            ``{"name": <tool name>, "parameters": <tool arguments>}``. If the
            model returns no tool call, or arguments that are not valid JSON,
            a default ``interact`` with ``["do"]`` is returned.

        Raises:
            Any exception raised by the OpenAI client. The Langfuse generation
            is closed before the exception propagates.
        """
        import contextlib

        # Update agent state
        self.agent_state["current_turn"] = turn
        self.agent_state["steps_taken"] = turn
        self.agent_state["steps_remaining"] = self.max_turns - turn

        # Only the current system prompt and observation are sent to the model.
        context = f"Turn {turn + 1}/{self.max_turns}\n\n{obs}"
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": context},
        ]

        self._trim_message_history()
        self.agent_state["message_history"].append({"role": "system", "content": system_message})
        self.agent_state["message_history"].append({"role": "user", "content": context})

        system_state_before = self._snapshot_state()

        langfuse_client = self._create_langfuse_client()
        generation = langfuse_client.generation(
            name=f"crafter_agent_turn_{turn}",
            model=self.model_name,
            input=messages,
            metadata={
                "turn": turn,
                "agent_type": self.system_name,
                "tools_available": len(self.tools),
            },
        )
        # Store langfuse client for cleanup
        self._langfuse_client = langfuse_client

        try:
            response = openai.chat.completions.create(
                model=self.model_name,
                messages=messages,
                tools=self.tools,
                tool_choice="auto",
                temperature=0.0,
            )
        except Exception:
            # Do not leave the generation open on failure; cleanup is best-effort
            # so the original error is what propagates.
            with contextlib.suppress(Exception):
                generation.end()
            raise

        choice = response.choices[0] if response.choices else None
        message = choice.message if choice is not None else None
        usage = response.usage

        generation.update(
            output=message.model_dump() if message is not None else None,
            usage={
                "prompt_tokens": usage.prompt_tokens if usage else None,
                "completion_tokens": usage.completion_tokens if usage else None,
                "total_tokens": usage.total_tokens if usage else None,
            },
        )
        generation.end()

        tool_calls = list((message.tool_calls if message is not None else None) or [])
        decision = self._parse_decision(tool_calls)

        # Update agent state with response (some endpoints report no token totals).
        if usage and usage.total_tokens:
            self.agent_state["total_tokens_used"] += usage.total_tokens
        self.agent_state["tool_calls_made"] += 1

        serialized_calls = [self._serialize_tool_call(tc) for tc in tool_calls]
        self.agent_state["message_history"].append({
            "role": "assistant",
            "content": message.content if message is not None else None,
            "tool_calls": serialized_calls,
        })

        system_state_after = self._snapshot_state()

        # Record LLM call as a CAISEvent (internal to agent, NOT a message)
        if self.tracer:
            self._record_llm_event(
                turn, messages, response, serialized_calls, system_state_before, system_state_after
            )

        return decision

    def _trim_message_history(self) -> None:
        """Once over the cap, keep the first history entry plus the most recent ones."""
        history = self.agent_state["message_history"]
        limit = self._MAX_HISTORY_LENGTH
        if len(history) > limit:
            self.agent_state["message_history"] = [history[0]] + history[-(limit - 1):]

    def _snapshot_state(self) -> Dict[str, Any]:
        """Return a copy of ``agent_state`` suitable for storing in a trace event.

        The message-history list is copied rather than shared, so later appends
        to the live history cannot alter an earlier snapshot. Histories longer
        than four entries are abbreviated to the first two, a marker, and the
        last two.
        """
        snapshot = dict(self.agent_state)
        if "message_history" in snapshot:
            history = list(snapshot["message_history"])
            if len(history) > 4:
                history = history[:2] + ["... truncated ..."] + history[-2:]
            snapshot["message_history"] = history
        return snapshot

    @staticmethod
    def _create_langfuse_client():
        """Construct a Langfuse client, discarding anything written to stdout/stderr meanwhile."""
        import contextlib
        import io

        with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
            return Langfuse()

    @staticmethod
    def _serialize_tool_call(tool_call) -> Dict[str, Any]:
        """Convert an SDK tool-call object into a plain, JSON-friendly dict."""
        return {
            "id": tool_call.id,
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
            "type": tool_call.type,
        }

    def _parse_decision(self, tool_calls: Any) -> Dict[str, Any]:
        """Turn the first tool call into a decision, falling back to a default ``do``."""
        if not tool_calls:
            if self.verbose:
                print("[WARNING] No tool calls returned by LLM, using default action")
            return {
                "name": "interact",
                "parameters": {
                    "actions": ["do"],
                    "reasoning": "Default action - no tool call received",
                },
            }

        first_call = tool_calls[0]
        try:
            arguments = json.loads(first_call.function.arguments)
        except (TypeError, ValueError) as exc:
            # Model-produced arguments are not guaranteed to be valid JSON.
            if self.verbose:
                print(f"[WARNING] Could not parse tool call arguments ({exc}), using default action")
            return {
                "name": "interact",
                "parameters": {
                    "actions": ["do"],
                    "reasoning": "Default action - could not parse tool call arguments",
                },
            }
        return {"name": first_call.function.name, "parameters": arguments}

    def _record_llm_event(self, turn, messages, response, serialized_calls, system_state_before, system_state_after) -> None:
        """Record the LLM call as a CAISEvent on the tracer's current session.

        Best-effort: any failure is reported on stdout and never aborts the turn.
        """
        try:
            choice = response.choices[0] if response.choices else None
            usage = response.usage

            # Create LLM call record with all info needed to reproduce
            llm_call_record = {
                "model": self.model_name,
                "messages": messages,
                "tools": self.tools,
                "tool_choice": "auto",
                "temperature": 0.0,
                "response": {
                    "id": response.id if response else None,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": choice.message.content,
                            "tool_calls": serialized_calls or None,
                        },
                        "finish_reason": choice.finish_reason,
                    }] if choice is not None else [],
                    "usage": {
                        "prompt_tokens": usage.prompt_tokens,
                        "completion_tokens": usage.completion_tokens,
                        "total_tokens": usage.total_tokens,
                    } if usage else None,
                },
            }

            llm_event = CAISEvent(
                time_record=TimeRecord(
                    event_time=datetime.now().isoformat(),
                    message_time=turn,
                ),
                system_instance_id=f"{self.system_name}_{self.model_name}",
                system_state_before=system_state_before,
                system_state_after=system_state_after,
                llm_call_records=[llm_call_record],
                metadata={
                    "model_name": self.model_name,
                    "prompt_tokens": usage.prompt_tokens if usage else None,
                    "completion_tokens": usage.completion_tokens if usage else None,
                    "total_tokens": usage.total_tokens if usage else None,
                    "turn": turn,
                },
            )

            if getattr(self.tracer, "current_session", None):
                self.tracer.record_event(llm_event)
                # Store the last event for progress tracking
                self.last_cais_event = llm_event
                if self.verbose:
                    print(f"✅ Added CAISEvent for turn {turn}")
            elif self.verbose:
                print("⚠️ No current_session in tracer")
        except Exception as exc:
            import traceback

            if self.verbose:
                print(f"Warning: Failed to capture LLM event: {exc}")
            # Always report tracing failures, but print the traceback only once.
            print(f"❌ Error adding CAISEvent: {exc}")
            traceback.print_exc()
|
1233
|
+
|
1234
|
+
|
1235
|
+
# --- Crafter ReAct Agent ---
|
1236
|
+
class CrafterReActAgent(BaseReActAgent):
    """ReAct agent for the Crafter environment.

    Adds a Crafter-specific system prompt and observation formatter on top of
    :class:`BaseReActAgent`.
    """

    def __init__(self, model_name: str, max_turns: int = 20, verbose: bool = False, tracer: Optional["SessionTracer"] = None):
        super().__init__(model_name, max_turns, verbose, tracer)
        self.system_name = "crafter-react-agent"

    def get_system_message(self) -> str:
        """Return the fixed Crafter system prompt (mechanics, actions, strategy)."""
        return """You are CrafterAgent playing Crafter survival environment. Your goal is to unlock as many achievements as possible while staying alive.

You will see a semantic map view showing your surroundings. Use this to navigate toward resources.

Key mechanics:
• 'do' action: collect wood from trees, stone from deposits, food from cows/plants
• 'do' does nothing on grass/water - move to find resources first
• Craft progression: wood → table → wood_pickaxe → stone → stone_pickaxe → iron tools
• Sleep when energy low to restore and unlock wake_up achievement
• Use semantic map view to navigate toward resources you can see

Available actions: move_left, move_right, move_up, move_down, do, sleep, place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword, noop

Strategy:
1. Look at the semantic map to see what's around you
2. Move toward trees to collect wood with 'do'
3. Once you have wood, place a table to enable crafting
4. Make a wood pickaxe to collect stone more efficiently
5. Progress to stone pickaxe, then iron tools
6. Eat food when health is low, sleep when energy is low

You should provide 1-5 actions in sequence for efficient gameplay. Use the semantic map view to navigate toward visible resources.

Example good action sequences:
- ['move_right', 'move_right', 'do'] (move to tree and collect wood)
- ['place_table', 'make_wood_pickaxe'] (craft progression)
- ['move_up', 'do', 'move_down', 'do'] (collect from multiple resources)

Be strategic and use the map view to find resources! Focus on unlocking achievements."""

    def format_observation(self, obs: Dict[str, Any]) -> str:
        """Format a Crafter observation dict into the per-turn user prompt.

        Reads these keys from *obs*:
          * ``health``, falling back to ``inventory["health"]`` when it is 0 or
            missing;
          * ``inventory``;
          * ``achievements`` or ``achievements_status``;
          * ``player_position`` / ``position``;
          * ``player_direction``, ``num_steps_taken`` and ``terminated``.

        A 5x5 semantic map view is appended. Keys whose value is ``None`` are
        treated as empty.
        """
        # `or {}` guards against keys that are present but None.
        inventory = obs.get("inventory") or {}
        health = obs.get("health", 0)

        # Extract health from inventory if not in main obs
        if health == 0 and "health" in inventory:
            health = inventory["health"]

        # Format inventory items (health is shown separately)
        inventory_items = [
            f"{item}: {count}"
            for item, count in inventory.items()
            if count > 0 and item != "health"
        ]
        inventory_str = ", ".join(inventory_items) if inventory_items else "empty"

        # Get achievements
        achievements = obs.get("achievements") or obs.get("achievements_status") or {}
        unlocked_achievements = [name for name, unlocked in achievements.items() if unlocked]
        achievements_str = ", ".join(unlocked_achievements) if unlocked_achievements else "none"

        # Get position and other state
        position = obs.get("position", [0, 0])
        player_position = obs.get("player_position", position)
        player_direction = obs.get("player_direction", [0, 1])
        num_steps = obs.get("num_steps_taken", 0)

        # Check termination status
        terminated = obs.get("terminated", False)

        # Get semantic map view
        map_view = format_semantic_map_view(obs, view_size=5)

        return (
            "Crafter Game State:\n"
            f"Step: {num_steps}\n"
            f"Health: {health}\n"
            f"Position: {player_position}\n"
            f"Direction: {player_direction}\n"
            f"Inventory: {inventory_str}\n"
            f"Achievements: {achievements_str}\n"
            f"Terminated: {terminated}\n"
            f"{map_view}\n\n"
            "Available actions: move_left, move_right, move_up, move_down, do, sleep, place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword, noop\n\n"
            "Key mechanics:\n"
            "• 'do' action: collect wood from trees, stone from deposits, food from cows/plants\n"
            "• 'do' does nothing on grass/water - move to find resources\n"
            "• Craft progression: wood → table → wood_pickaxe → stone → stone_pickaxe → iron tools\n"
            "• Sleep when energy low to restore and unlock wake_up achievement\n\n"
            "Choose 1-5 actions to execute in sequence. Focus on exploring to find resources and crafting tools to unlock achievements."
        )
|
1327
|
+
|
1328
|
+
|
1329
|
+
# --- Episode Runner ---
|
1330
|
+
async def run_single_episode(
    client: AsyncClient, agent: CrafterReActAgent, task_instance, instance_num: int, traces_dir: str = "traces",
    config=None, episode_progress_bars=None, all_achievements=None, all_invalid_actions=None, all_total_actions=None,
    experiment_id: Optional[str] = None
) -> Dict[str, Any]:
    """Run a single Crafter episode and return episode metrics.

    A ``SessionTracer`` is created for the episode and attached to
    ``agent.tracer``. The environment is initialized through the service at
    ``/env/CrafterClassic/initialize``. Then, for up to ``agent.max_turns``
    turns, the agent chooses a sequence of action names. These are mapped to
    Crafter action ids and sent one at a time to ``/env/CrafterClassic/step``.
    Each turn records the following on the tracer:

    - observation, tool-call and action messages;
    - one ``RuntimeEvent``;
    - one ``EnvironmentEvent``.

    Args:
        client: HTTP client for the environment service (relative URLs are used).
        agent: The ReAct agent. ``agent.tracer`` is set for the episode and
            reset to ``None`` after a successful run.
        task_instance: Crafter task instance used to initialize the environment.
        instance_num: Episode index; part of the session id and the key into
            ``episode_progress_bars``.
        traces_dir: Directory handed to ``SessionTracer``.
        config: Not used by this function; accepted for call-site compatibility.
        episode_progress_bars: Optional mapping ``instance_num -> tqdm`` bar.
        all_achievements: Optional counter (e.g. ``defaultdict(int)``),
            incremented once for every achievement unlocked in this episode.
        all_invalid_actions: Optional counter fed from an ``InvalidActionHook``
            attached to the tracer, if one is present.
        all_total_actions: Optional counter fed from the same hook.
        experiment_id: Optional experiment id forwarded to ``SessionTracer``.

    Returns:
        A dict with ``eval_metric`` (count of unlocked achievements as a float
        on success), ``rubric``, ``total_reward``, ``num_achievements``,
        ``terminated``, ``termination_reason`` and ``error``. Successful runs
        additionally include ``achievements``, ``rollout_length`` and
        ``timing``.
    """
    import traceback
    import uuid

    # Crafter action name -> integer action id sent to the service.
    # Names not in this map are sent as 0 ("noop").
    CRAFTER_ACTION_MAP = {
        "noop": 0,
        "move_left": 1,
        "move_right": 2,
        "move_up": 3,
        "move_down": 4,
        "do": 5,
        "sleep": 6,
        "place_stone": 7,
        "place_table": 8,
        "place_furnace": 9,
        "place_plant": 10,
        "make_wood_pickaxe": 11,
        "make_stone_pickaxe": 12,
        "make_iron_pickaxe": 13,
        "make_wood_sword": 14,
        "make_stone_sword": 15,
        "make_iron_sword": 16,
    }

    # Achievements reported in the per-episode rubric, in output order.
    RUBRIC_ACHIEVEMENTS = (
        "collect_wood",
        "collect_stone",
        "collect_coal",
        "collect_iron",
        "collect_diamond",
        "place_table",
        "place_furnace",
        "make_wood_pickaxe",
        "make_stone_pickaxe",
        "make_iron_pickaxe",
        "make_wood_sword",
        "make_stone_sword",
        "make_iron_sword",
        "defeat_skeleton",
        "defeat_zombie",
        "wake_up",
        "eat_cow",
        "eat_plant",
    )

    # Observation keys excluded from the "public_state" of environment events.
    NON_PUBLIC_OBS_KEYS = frozenset(
        ("reward_last_step", "total_reward_episode", "terminated", "truncated", "tool_error")
    )

    def _action_arg(decision, key, default):
        # Look up `key` in decision["parameters"], falling back to decision["arguments"].
        return decision.get("parameters", {}).get(
            key, decision.get("arguments", {}).get(key, default)
        )

    def _split_state(observation):
        # Split a compressed observation into the public/private state shape
        # recorded on EnvironmentEvent.
        return {
            "public_state": {
                k: v for k, v in observation.items() if k not in NON_PUBLIC_OBS_KEYS
            },
            "private_state": {
                "reward_last_step": observation.get("reward_last_step", 0.0),
                "total_reward_episode": observation.get("total_reward_episode", 0.0),
                "terminated": observation.get("terminated", False),
                "truncated": observation.get("truncated", False),
            },
        }

    # Timing tracking
    episode_start_time = time.time()
    step_times = []  # Wall time per turn
    env_times = []  # Wall time per /step HTTP call
    agent_times = []  # Wall time per agent decision

    # Create session tracer for this episode with hooks and DuckDB storage
    # DuckDB will be auto-enabled based on LOCAL_SYNTH config
    tracer = SessionTracer(
        traces_dir,
        hooks=CRAFTER_HOOKS,
        duckdb_path="synth_ai/traces/crafter_traces.duckdb",
        experiment_id=experiment_id
    )
    session_id = f"crafter_episode_{instance_num}_{task_instance.id}"
    tracer.start_session(session_id)

    def _record_message(message) -> None:
        # Messages go straight into the session history while a session is open.
        if tracer.current_session:
            tracer.current_session.add_message(message)

    # Progress bars already initialized in batch processing

    # Create system IDs
    runtime_id = uuid.uuid4()  # Runtime converts tool calls to actions
    environment_id = task_instance.id  # Use task instance ID for environment

    # Add episode metadata
    tracer.add_session_metadata("episode_config", {
        "instance_num": instance_num,
        "task_instance_id": str(task_instance.id),
        "difficulty": task_instance.metadata.difficulty,
        "max_turns": agent.max_turns,
        "model_name": agent.model_name,
        "agent_type": agent.system_name
    })

    # Add system ID mapping
    tracer.add_session_metadata("system_ids", {
        "agent": str(agent.system_id),
        "runtime": str(runtime_id),
        "environment": str(environment_id)
    })

    # Update agent with tracer
    agent.tracer = tracer

    try:
        # Create environment using the task instance
        create_resp = await retry_http_request(
            client, "POST", "/env/CrafterClassic/initialize",
            json={"task_instance": await task_instance.serialize()}
        )

        if create_resp.status_code != 200:
            return {
                "eval_metric": 0.0,
                "rubric": {},
                "total_reward": 0.0,
                "num_achievements": 0,
                "terminated": False,
                "termination_reason": f"http_error: {create_resp.status_code}",
                "error": True,
            }

        create_payload = create_resp.json()
        env_id = create_payload["env_id"]

        # Get initial observation
        obs = create_payload["observation"]
        formatted_obs = agent.format_observation(obs)

        # Track episode metrics
        total_reward = 0.0
        termination_reason = "max_turns_reached"  # Default assumption
        final_achievements = {}
        num_achievements = 0
        terminated = False
        rollout_length = 0

        # No tqdm wrapper here: per-episode bars are driven manually below.
        for turn in range(agent.max_turns):
            step_start_time = time.time()
            try:
                # A timestep must be open for tracer.record_event to work.
                if tracer:
                    tracer.start_timestep(f"turn_{turn}")

                # Record observation message (Environment → Runtime → Agent)
                obs_for_trace = compress_observation_for_trace(obs)
                _record_message(create_message(
                    content=obs_for_trace,  # Compressed observation
                    message_type="observation",
                    origin_system_id=environment_id,  # From environment
                    turn=turn
                ))

                # Get agent decision
                agent_start_time = time.time()
                action = agent.decide(formatted_obs, agent.get_system_message(), turn)
                agent_times.append(time.time() - agent_start_time)

                # Record tool call message (Agent → Runtime)
                _record_message(create_message(
                    content=[{
                        "tool": action["name"],
                        "args": action["parameters"]
                    }],
                    message_type="tool_call",
                    origin_system_id=agent.system_id,  # From agent
                    turn=turn
                ))

                # Check for termination requested by the agent
                if action["name"] == "terminate":
                    reason = _action_arg(action, "reason", "no reason given")
                    termination_reason = f"agent_terminate: {reason}"
                    if episode_progress_bars is not None and instance_num in episode_progress_bars:
                        episode_progress_bars[instance_num].update(1)
                        episode_progress_bars[instance_num].set_description(
                            f"Episode {instance_num:2d} [T]"  # T for terminated
                        )
                    break

                action_sequence = _action_arg(action, "actions", [])
                if not action_sequence:
                    print(f"    ⚠️  No actions found in: {action}")
                    termination_reason = "no_actions_provided"
                    break

                # Convert action names to Crafter action ids (unknown -> noop)
                action_ints = [CRAFTER_ACTION_MAP.get(name, 0) for name in action_sequence]

                # Record runtime event before executing actions
                prev_obs = compress_observation_for_trace(obs)
                runtime_event = RuntimeEvent(
                    system_state_before={"observation": prev_obs},
                    actions=action_ints,
                    metadata={
                        "action_names": action_sequence,
                        "action_reasoning": action.get("parameters", {}).get("reasoning", ""),
                        "turn": turn
                    }
                )

                # Record action messages (Runtime → Environment)
                for action_int in action_ints:
                    _record_message(create_message(
                        content={
                            "action": action_int,
                            "action_type": "crafter_action"
                        },
                        message_type="action",
                        origin_system_id=runtime_id,  # From runtime
                        turn=turn
                    ))

                # Execute each action individually (Crafter expects single actions)
                for i, action_int in enumerate(action_ints):
                    try:
                        # Time just the HTTP request
                        env_start_time = time.time()
                        step_resp = await retry_http_request(
                            client, "POST", "/env/CrafterClassic/step",
                            json={
                                "env_id": env_id,
                                "request_id": str(uuid.uuid4()),
                                "action": {
                                    "tool_calls": [{"tool": "interact", "args": {"action": action_int}}]
                                },
                            }
                        )
                        env_times.append(time.time() - env_start_time)

                        if step_resp.status_code != 200:
                            print(
                                f"    ❌ Action {i + 1} failed: {step_resp.status_code}: {step_resp.text}"
                            )
                            termination_reason = f"http_error: {step_resp.status_code}"
                            break

                        # Update observation after each action
                        obs = step_resp.json()["observation"]

                    except Exception as e:
                        print(f"    ❌ Action {i + 1} failed after retries: {type(e).__name__}: {str(e)[:100]}")
                        termination_reason = f"http_error: {type(e).__name__}"
                        break

                # Stop the episode if the action loop broke on an HTTP error
                if termination_reason.startswith("http_error"):
                    break

                # Record runtime event now that we have the final state
                obs_for_trace = compress_observation_for_trace(obs)
                runtime_event.system_state_after = {"observation": obs_for_trace}
                if tracer.current_session:
                    tracer.record_event(runtime_event)

                # Record environment event for the state transition
                env_event = EnvironmentEvent(
                    time_record=TimeRecord(
                        event_time=datetime.now().isoformat(),
                        message_time=turn
                    ),
                    system_instance_id=str(environment_id),
                    system_state_before=_split_state(prev_obs),
                    system_state_after=_split_state(obs_for_trace),
                    reward=obs.get("reward_last_step", 0.0),
                    terminated=obs.get("terminated", False),
                    metadata={
                        "actions_executed": action_sequence,
                        "turn": turn
                    }
                )
                if tracer.current_session:
                    tracer.record_event(env_event)

                # Prepare the observation for the next decision
                formatted_obs = agent.format_observation(obs)
                rollout_length = obs.get("num_steps_taken", 0)

                # Track step timing
                step_times.append(time.time() - step_start_time)

                # Advance this episode's progress bar by one turn. The bar's
                # description is refreshed below once the achievement count
                # for this turn is known.
                if episode_progress_bars is not None and instance_num in episode_progress_bars:
                    episode_progress_bars[instance_num].update(1)

                # Update history with safer access
                reasoning = _action_arg(action, "reasoning", "")
                agent.history.append(f"{', '.join(action_sequence)}: {reasoning[:50]}")

                # Track episode progress - use the FINAL observation from the last action
                terminated = obs.get("terminated", False)
                # NOTE(review): other code here reads "reward_last_step"; "reward"
                # may be absent from observations — confirm. total_reward is
                # replaced by the achievement-based reward after the loop anyway.
                step_reward = obs.get("reward", 0.0)
                total_reward += step_reward
                achievements = obs.get("achievements") or obs.get("achievements_status", {})

                # ALWAYS update final_achievements with the latest observation
                final_achievements = achievements
                num_achievements = sum(1 for v in achievements.values() if v) if achievements else 0

                # Update progress bar description with achievements
                if episode_progress_bars is not None and instance_num in episode_progress_bars:
                    episode_progress_bars[instance_num].set_description(
                        f"Episode {instance_num:2d} [{num_achievements} achv]"
                    )

                if terminated:
                    # Distinguish death from other environment termination
                    health = obs.get("health", 9)  # Default health value
                    termination_reason = "death" if health <= 0 else "environment_terminated"
                    break

            except Exception as e:
                error_msg = str(e) if str(e) else f"{type(e).__name__}"
                termination_reason = f"exception: {error_msg[:50]}"
                print(f"    ❌ Episode {instance_num} failed with exception: {error_msg}")
                traceback.print_exc()
                break

        # Close/finish the tqdm progress bar for this episode
        if episode_progress_bars is not None and instance_num in episode_progress_bars:
            bar = episode_progress_bars[instance_num]
            # Ensure the progress bar reaches 100% if not already
            remaining = bar.total - bar.n
            if remaining > 0:
                bar.update(remaining)
            bar.close()

        # Cleanup
        await client.post("/env/CrafterClassic/terminate", json={"env_id": env_id})

        # Add final episode metadata to trace
        tracer.add_session_metadata("episode_results", {
            "total_reward": total_reward,
            "num_achievements": num_achievements,
            "achievements": final_achievements,
            "rollout_length": rollout_length,
            "terminated": terminated
        })

        # End session - only upload to DuckDB, no JSON saving
        tracer.end_session(save=False, upload_to_db=True)

        # Clear the tracer reference to allow garbage collection
        agent.tracer = None

        # Track achievements globally
        if all_achievements is not None and final_achievements:
            for achievement, unlocked in final_achievements.items():
                if unlocked:
                    all_achievements[achievement] += 1

        # Track invalid actions globally via the tracer's InvalidActionHook
        if all_invalid_actions is not None and tracer.hooks:
            for hook in tracer.hooks:
                if hook.__class__.__name__ != 'InvalidActionHook':
                    continue
                if hasattr(hook, 'invalid_actions'):
                    for action_name, count in hook.invalid_actions.items():
                        all_invalid_actions[action_name] += count
                if hasattr(hook, 'total_actions') and all_total_actions is not None:
                    for action_name, count in hook.total_actions.items():
                        all_total_actions[action_name] += count

        # Calculate K-weighted achievement reward: k * log(1 + 1) per unlocked achievement
        achievement_reward = 0.0
        if final_achievements:
            for achievement, unlocked in final_achievements.items():
                if unlocked and achievement in ACHIEVEMENT_K_VALUES:
                    achievement_reward += ACHIEVEMENT_K_VALUES[achievement] * math.log(2)

        # Use achievement reward as the total reward
        total_reward = achievement_reward

        # Simple eval metric: number of achievements
        eval_metric = float(num_achievements)

        # Rubric: 1.0 per unlocked achievement, 0.0 otherwise (all zeros if none reported)
        rubric = {
            name: 1.0 if final_achievements and final_achievements.get(name, False) else 0.0
            for name in RUBRIC_ACHIEVEMENTS
        }

        episode_total_time = time.time() - episode_start_time

        return {
            "eval_metric": eval_metric,
            "rubric": rubric,
            "total_reward": total_reward,
            "num_achievements": num_achievements,
            "achievements": final_achievements,
            "rollout_length": rollout_length,
            "terminated": terminated,
            "termination_reason": termination_reason,
            "error": False,
            "timing": {
                "episode_total_time": episode_total_time,
                "step_times": step_times,
                "env_times": env_times,
                "agent_times": agent_times,
            }
        }

    except Exception as e:
        # Save trace even on error
        try:
            tracer.add_session_metadata("error", {"error_message": str(e)})
            tracer.end_session(save=True)
        except Exception:
            pass  # Don't let trace saving errors mask the original error

        error_msg = str(e) if str(e) else f"{type(e).__name__}"
        clear_progress_display()
        print(f"    ❌ Episode {instance_num} failed with outer exception: {error_msg}")
        traceback.print_exc()

        return {
            "eval_metric": 0.0,
            "rubric": {},
            "total_reward": 0.0,
            "num_achievements": 0,
            "terminated": False,
            "termination_reason": f"outer_exception: {error_msg[:50]}",
            "error": True,
        }
|
1890
|
+
|
1891
|
+
|
1892
|
+
# --- Batch Evaluation ---
|
1893
|
+
async def evaluate_crafter_batch() -> Dict[str, Any]:
    """Evaluate the Crafter agent on ``config.num_instances`` task instances.

    Reads model, difficulty, instance count, turn budget and service URL from
    the module-level ``config``. Steps:

    1. Register an experiment in the DuckDB trace database.
    2. Build one task instance per seed ``0 .. num_instances - 1``.
    3. Run every episode concurrently with ``asyncio.gather``.
    4. Aggregate metrics over the episodes that finished without
       ``error`` and without an "exception" termination reason.

    Returns:
        Aggregate statistics (eval metrics, mean rubric, achievement counts,
        shaped reward, rollout-length quantiles, action counters, termination
        reasons, raw per-episode results and the experiment context). If no
        episode produced a valid result, a reduced dict is returned with
        ``num_episodes == 0``.
    """
    print(f"🎯 Evaluating Crafter on {config.num_instances} {config.difficulty} instances...")

    db_path = "synth_ai/traces/crafter_traces.duckdb"

    # Create experiment context; the manager is closed once it has been created.
    with DuckDBTraceManager(db_path) as db_manager:
        experiment_context = create_experiment_context(
            db_manager,
            experiment_name=None,  # Will auto-generate pet name
            description=f"Crafter evaluation: {config.num_instances} {config.difficulty} instances with {config.model_name}",
            system_name="crafter-react-agent",
            system_description=f"ReAct agent for Crafter using {config.model_name}"
        )

    experiment_id = experiment_context["experiment_id"]

    # Print experiment header with clear formatting
    separator = "=" * 80
    print(f"\n{separator}")
    print(f"🧪 EXPERIMENT: {experiment_context['experiment_name']}")
    print(separator)
    print(f"📋 Experiment ID: {experiment_id}")
    print(f"🤖 System: {experiment_context['system_id']}")
    print(f"📌 Version: {experiment_context['system_version_id']}")
    print(f"🌿 Git Branch: {experiment_context['git_branch']}")
    print(f"📍 Git Commit: {experiment_context['git_commit']}")
    print(f"💾 Database: {db_path}")
    print(f"{separator}\n")

    # Traces are now saved to DuckDB only
    script_dir = Path(__file__).parent
    traces_dir = script_dir / "traces"  # Keep for compatibility but not used

    # Cross-episode counters, filled in by run_single_episode
    all_achievements = defaultdict(int)
    all_invalid_actions = defaultdict(int)
    all_total_actions = defaultdict(int)

    # Print hook legend
    print_hook_legend()

    # Track trace files created during this run
    initial_trace_files = set(traces_dir.glob("*.json"))

    # Create one tqdm progress bar per episode (keyed by 1-based episode index)
    from tqdm.asyncio import tqdm
    episode_pbar_dict = {
        i + 1: tqdm(
            total=config.max_turns,
            desc=f"Episode {i + 1:2d}",
            position=i,
            leave=True,
            unit="steps"
        )
        for i in range(config.num_instances)
    }

    # Build task instances using the taskset system
    from synth_ai.environments.examples.crafter_classic.taskset import (
        CrafterTaskInstance,
        CrafterTaskInstanceMetadata,
    )
    from synth_ai.environments.tasks.core import Impetus, Intent

    easy_task_instances = []
    for seed in range(config.num_instances):
        try:
            metadata = CrafterTaskInstanceMetadata(
                difficulty=config.difficulty,
                seed=seed,
                num_trees_radius=5,  # Good for easy difficulty
                num_cows_radius=2,
                num_hostiles_radius=0,  # No hostiles for easy
            )
            task_instance = CrafterTaskInstance(
                id=uuid.uuid4(),
                impetus=Impetus(
                    instructions=f"Survive and unlock achievements in an {config.difficulty} environment."
                ),
                intent=Intent(rubric={}, gold_trajectories=None, gold_state_diff={}),
                metadata=metadata,
                is_reproducible=True,
                initial_engine_snapshot=None,
            )
            easy_task_instances.append(task_instance)
        except Exception as e:
            print(f"   ⚠️  Failed to create task instance for seed {seed}: {e}")
            continue

    # Configure connection pool to prevent ReadErrors
    limits = httpx.Limits(
        max_keepalive_connections=30,  # Increase for better concurrency
        max_connections=100,  # More connections for parallel episodes
        keepalive_expiry=60.0  # Longer keepalive
    )
    transport = httpx.AsyncHTTPTransport(
        http2=True,
        limits=limits,
        retries=3  # Transport-level retries
    )
    async with AsyncClient(
        base_url=config.service_base_url,
        timeout=httpx.Timeout(
            connect=5.0,
            read=HTTP_TIMEOUT,
            write=5.0,
            pool=5.0
        ),
        transport=transport,
        limits=limits
    ) as client:
        # Run ALL trajectories in parallel (no batching), one fresh agent each
        all_tasks = [
            run_single_episode(
                client,
                CrafterReActAgent(config.model_name, max_turns=config.max_turns, verbose=False),
                task_instance,
                i + 1,
                str(traces_dir),
                config,
                episode_pbar_dict,
                all_achievements,
                all_invalid_actions,
                all_total_actions,
                experiment_id=experiment_id,
            )
            for i, task_instance in enumerate(easy_task_instances)
        ]

        all_results = await asyncio.gather(*all_tasks)

    # Close all progress bars
    for pbar in episode_pbar_dict.values():
        pbar.close()

    # Filter out error results and exception-terminated episodes
    valid_results = []
    all_termination_reasons = []
    excluded_count = 0
    for r in all_results:
        termination_reason = r.get("termination_reason", "unknown")
        all_termination_reasons.append(termination_reason)

        # Exclude episodes that errored or ended with exceptions
        if r.get("error", False) or "exception" in termination_reason:
            excluded_count += 1
            print(f"   ⚠️  Excluding episode from stats due to: {termination_reason}")
            continue
        valid_results.append(r)

    if excluded_count > 0:
        print(f"   📊 Excluded {excluded_count} episodes from aggregate statistics due to errors/exceptions")

    if not valid_results:
        return {
            "eval_metrics": [],
            "mean_eval_metric": 0.0,
            "mean_rubric": {},
            "num_episodes": 0,
            # Diagnostics so callers can see why nothing was valid
            "termination_reasons": all_termination_reasons,
            "raw_results": all_results,
            "experiment_context": experiment_context,
        }

    # Extract eval metrics
    eval_metrics = [r["eval_metric"] for r in valid_results]
    mean_eval_metric = sum(eval_metrics) / len(eval_metrics)

    # --- Rollout length statistics ---
    sorted_lengths = sorted(r["rollout_length"] for r in valid_results)
    n_lengths = len(sorted_lengths)
    mid = n_lengths // 2
    # Median (Q2): middle element, or mean of the two middle elements
    if n_lengths % 2 == 1:
        q2_rollout = sorted_lengths[mid]
    else:
        q2_rollout = (sorted_lengths[mid - 1] + sorted_lengths[mid]) / 2
    # 90th percentile (P90), lower nearest-rank
    p90_rollout = sorted_lengths[int(0.9 * (n_lengths - 1))]
    max_rollout = sorted_lengths[-1]

    # Mean rubric values; keys kept in first-seen order so output is deterministic
    rubric_keys = dict.fromkeys(key for r in valid_results for key in r["rubric"])
    mean_rubric = {
        key: sum(r["rubric"].get(key, 0.0) for r in valid_results) / len(valid_results)
        for key in rubric_keys
    }

    # Count achievements across episodes for the shaped (training) reward
    achievement_counts = {}
    unique_achievements_per_trajectory = []
    all_unique_achievements = set()

    for result in valid_results:
        trajectory_achievements = {
            name for name, unlocked in result.get("achievements", {}).items() if unlocked
        }
        for name in trajectory_achievements:
            achievement_counts[name] = achievement_counts.get(name, 0) + 1
        all_unique_achievements |= trajectory_achievements
        unique_achievements_per_trajectory.append(trajectory_achievements)

    shaped_reward_data = calculate_shaped_reward(achievement_counts)

    # Unique achievements seen within the first N trajectories (running union, O(n))
    unique_achievements_by_n = {}
    seen_so_far = set()
    for n, trajectory_achievements in enumerate(unique_achievements_per_trajectory, start=1):
        seen_so_far |= trajectory_achievements
        unique_achievements_by_n[n] = len(seen_so_far)

    # Training rubric: shaped reward contributions normalized per episode
    total_episodes = len(valid_results)
    training_rubric = {
        achievement: data["contribution"] / total_episodes
        for achievement, data in (shaped_reward_data["breakdown"] or {}).items()
    }

    # Get newly created trace files
    final_trace_files = set(traces_dir.glob("*.json"))
    new_trace_files = sorted(final_trace_files - initial_trace_files)

    return {
        "eval_metrics": eval_metrics,
        "mean_eval_metric": mean_eval_metric,
        "mean_rubric": mean_rubric,
        "achievement_counts": achievement_counts,
        "shaped_reward_data": shaped_reward_data,
        "training_rubric": training_rubric,
        "unique_achievements_per_trajectory": unique_achievements_per_trajectory,
        "all_unique_achievements": all_unique_achievements,
        "unique_achievements_by_n": unique_achievements_by_n,
        "num_episodes": len(valid_results),
        "q2_rollout": q2_rollout,
        "p90_rollout": p90_rollout,
        "max_rollout": max_rollout,
        "new_trace_files": new_trace_files,
        "all_achievements": dict(all_achievements),  # Add the tracked achievements
        "all_invalid_actions": dict(all_invalid_actions),  # Add the tracked invalid actions
        "all_total_actions": dict(all_total_actions),  # Add the total actions
        "termination_reasons": all_termination_reasons,  # Add termination reasons
        "raw_results": all_results,  # Add raw results for timing analysis
        "experiment_context": experiment_context,  # Add experiment context for tracking
    }
|
2140
|
+
|
2141
|
+
|
2142
|
+
def _quiet_noisy_loggers() -> None:
    """Raise the log level of chatty third-party and synth_ai loggers.

    Keeps the console focused on the evaluation summary printed by ``main``.
    """
    levels = {
        # HTTP client and Google GenAI SDK chatter
        "httpx": logging.WARNING,
        "google_genai": logging.ERROR,
        "google.generativeai": logging.ERROR,
        "google_genai.models": logging.ERROR,
        "google_genai.types": logging.ERROR,
        "google_genai._api_client": logging.ERROR,
        # Disable synth_ai LM debug logs
        "synth_ai.lm.provider_support.openai": logging.WARNING,
        "synth_ai.lm": logging.WARNING,
        # Suppress worker/trajectory logs from other systems
        "synth_ai": logging.WARNING,
        "synth_ai.environments": logging.WARNING,
        "asyncio": logging.WARNING,
        # Suppress Langfuse client warnings
        "langfuse": logging.ERROR,
        "langfuse.client": logging.ERROR,
    }
    for logger_name, level in levels.items():
        logging.getLogger(logger_name).setLevel(level)
    # Set root logger to WARNING to suppress debug prints
    logging.getLogger().setLevel(logging.WARNING)


def _print_run_header() -> None:
    """Print the run configuration banner (reads the module-level ``config``)."""
    print("🎮 Crafter ReAct Agent Evaluation")
    print(f"Model: {config.model_name}")
    print(f"Service: {config.service_base_url}")
    print(f"Instances: {config.num_instances}")
    print(f"Max Turns: {config.max_turns}")
    print(f"Difficulty: {config.difficulty}")
    print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("=" * 50)


async def _check_service_health() -> bool:
    """Check that the environment service is up and supports CrafterClassic.

    Prints a one-line status message in every case.

    Returns:
        True if ``GET /health`` succeeded and lists ``CrafterClassic`` in
        ``supported_environments``; False otherwise.
    """
    # Use same connection pool settings for health check
    limits = httpx.Limits(
        max_keepalive_connections=30,
        max_connections=100,
        keepalive_expiry=60.0,
    )
    transport = httpx.AsyncHTTPTransport(
        http2=True,
        limits=limits,
        retries=3,  # Transport-level retries
    )
    # When an explicit transport is supplied, httpx uses the transport's pool
    # settings, so the limits are configured on the transport only.
    async with AsyncClient(
        base_url=config.service_base_url,
        timeout=httpx.Timeout(connect=5.0, read=HTTP_TIMEOUT, write=5.0, pool=5.0),
        transport=transport,
    ) as client:
        try:
            health_resp = await retry_http_request(client, "GET", "/health")
            health_data = health_resp.json()

            if "CrafterClassic" not in health_data.get("supported_environments", []):
                print("❌ CrafterClassic not available on service")
                return False

            print("✅ Service health check passed")
            return True
        except Exception as e:
            print(f"❌ Service health check failed after retries: {type(e).__name__}: {str(e)[:100]}")
            return False


def _print_results(results: dict) -> None:
    """Print the summary table, experiment info and the detailed/condensed breakdowns.

    Args:
        results: The dict returned by ``evaluate_crafter_batch``. Required keys
            are ``num_episodes`` and ``mean_eval_metric``; the other keys are
            read with defaults.
    """
    print("\n" + "=" * 80)
    print("🏆 CRAFTER EVALUATION RESULTS")
    print("=" * 80)

    num_episodes = results["num_episodes"]
    mean_score = results["mean_eval_metric"]

    # Check if any episodes were excluded: termination reasons are presumably
    # recorded for every episode, while num_episodes counts only valid ones.
    total_episodes = len(results.get("termination_reasons", []))
    if total_episodes > num_episodes:
        excluded = total_episodes - num_episodes
        print(f"📝 Note: {excluded} episode(s) excluded from statistics due to errors/exceptions")
        print("=" * 80)

    # Calculate key metrics
    achievements_per_trajectory = [
        len(achievements)
        for achievements in results.get("unique_achievements_per_trajectory", [])
    ]
    avg_achievements = (
        sum(achievements_per_trajectory) / len(achievements_per_trajectory)
        if achievements_per_trajectory
        else 0.0
    )
    all_unique = results.get("all_unique_achievements", set())
    total_unique = len(all_unique)
    shaped_reward = results.get("shaped_reward_data", {}).get("total_shaped_reward", 0.0)
    mean_k_score = shaped_reward / num_episodes if num_episodes > 0 else 0.0

    # Assessment bucket based on mean score
    if mean_score >= 3.0:
        assessment = "🎉 Excellent"
    elif mean_score >= 1.0:
        assessment = "✅ Good"
    elif mean_score >= 0.5:
        assessment = "⚠️ Moderate"
    else:
        assessment = "📈 Learning"

    # Create dense results table
    print("┌───────────────────────────┬───────────────────────────────────────────┐")
    print("│ Metric │ Value │")
    print("├───────────────────────────┼───────────────────────────────────────────┤")
    print(f"│ Model │ {config.model_name:<41} │")
    print(f"│ Episodes │ {num_episodes:<41} │")
    print(f"│ Mean Score │ {mean_score:<41.2f} │")
    print(f"│ Assessment │ {assessment:<41} │")
    print(f"│ Avg Achievements/Episode │ {avg_achievements:<41.2f} │")
    print(f"│ Unique Achievements │ {total_unique:<41} │")
    print(f"│ Shaped Reward (Total) │ {shaped_reward:<41.3f} │")
    print(f"│ Mean K-Score/Episode │ {mean_k_score:<41.3f} │")
    print(f"│ Q2 Rollout Length │ {results.get('q2_rollout', 0):<41} │")
    print(f"│ Max Rollout Length │ {results.get('max_rollout', 0):<41} │")

    # Show unlocked achievements, truncated to fit the 41-char value column
    if all_unique:
        achievements_str = ", ".join(sorted(all_unique))
        if len(achievements_str) > 41:
            achievements_str = achievements_str[:38] + "..."
        print(f"│ Unlocked Achievements │ {achievements_str:<41} │")

    print("└───────────────────────────┴───────────────────────────────────────────┘")

    # Print experiment info
    exp_ctx = results.get("experiment_context")
    if exp_ctx:
        # Tolerate a missing/None commit so the rest of the summary still prints.
        git_commit = str(exp_ctx.get("git_commit") or "unknown")
        print(f"\n{'=' * 80}")
        print("📊 EXPERIMENT SAVED TO DUCKDB")
        print("=" * 80)
        print(f"🧪 Name: {exp_ctx['experiment_name']}")
        print(f"📋 ID: {exp_ctx['experiment_id']}")
        print(f"🌿 Git: {exp_ctx.get('git_branch')} @ {git_commit[:8]}")
        print("\n🔍 Query this experiment:")
        print(f" python -m synth_ai.tui.cli.query_experiments -e {exp_ctx['experiment_id'][:8]}")
        print("\n📊 Or view all experiments:")
        print(" python -m synth_ai.tui.cli.query_experiments")
        print("=" * 80)

    all_achievements = results.get("all_achievements", {})
    invalid_actions = results.get("all_invalid_actions", {})
    total_actions = results.get("all_total_actions", {})
    termination_reasons = results.get("termination_reasons", [])
    raw_results = results.get("raw_results", [])

    if getattr(config, "verbose_output", False):
        # Verbose: one detailed table per non-empty section
        if all_achievements:
            print_achievements_table(all_achievements, num_episodes)
        if invalid_actions:
            print_invalid_actions_table(invalid_actions, total_actions)
        if termination_reasons:
            print_termination_breakdown(termination_reasons)
        if raw_results:
            print_timing_analysis(raw_results)
    else:
        # Display condensed summary (default)
        print_condensed_summary(
            all_achievements,
            invalid_actions,
            total_actions,
            termination_reasons,
            raw_results,
            num_episodes,
        )


async def main():
    """Run the Crafter evaluation end to end.

    Quiets noisy loggers, prints the run banner, checks the environment
    service, runs ``evaluate_crafter_batch`` and prints the results. Reads the
    module-level ``config`` and stores ``config._run_start_time`` for trace
    filtering. Returns early (without raising) if the health check fails;
    errors during evaluation/reporting are printed with a traceback instead of
    being propagated.
    """
    import time

    # Record start time for trace filtering
    config._run_start_time = time.time()

    _quiet_noisy_loggers()
    _print_run_header()

    if not await _check_service_health():
        return

    # Run evaluation
    try:
        results = await evaluate_crafter_batch()
        _print_results(results)
    except Exception as e:
        import traceback

        print(f"❌ Evaluation failed: {e}")
        traceback.print_exc()
|
2347
|
+
|
2348
|
+
|
2349
|
+
if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Run Crafter ReAct Agent Evaluation")
    parser.add_argument("--config", "-c", type=str, help="Path to TOML configuration file")
    parser.add_argument("--model", "-m", type=str, help="Model name (overrides config)")
    parser.add_argument("--episodes", "-e", type=int, help="Number of episodes (overrides config)")
    parser.add_argument(
        "--max-turns", "-t", type=int, help="Maximum turns per episode (overrides config)"
    )
    parser.add_argument("--difficulty", "-d", type=str, help="Difficulty level (overrides config)")
    parser.add_argument("--analyze-traces", action="store_true", help="Analyze trace files after evaluation")
    parser.add_argument("--evaluate-traces", action="store_true", help="Run trace evaluation scoring after episodes")
    parser.add_argument("--verbose", action="store_true", help="Use verbose output format (detailed tables and statistics)")

    # Custom OpenAI endpoint support
    parser.add_argument(
        "--openai-base-url",
        type=str,
        help="Custom OpenAI-compatible base URL (e.g., https://lora-service.modal.run)",
    )
    parser.add_argument(
        "--openai-api-key",
        type=str,
        default="dummy",
        help="API key for custom endpoint (default: 'dummy')",
    )

    args = parser.parse_args()

    # Load configuration: explicit --config, else the repo default TOML if it
    # exists, else CrafterConfig's built-in defaults.
    if args.config:
        config = CrafterConfig(args.config)
    else:
        default_config_path = (
            Path(__file__).parent.parent.parent.parent / "evals" / "configs" / "crafter.toml"
        )
        if default_config_path.exists():
            config = CrafterConfig(str(default_config_path))
        else:
            config = CrafterConfig()

    # Override with command line arguments (falsy values such as 0 or "" keep
    # the config value, matching the previous behavior).
    if args.model:
        config.model_name = args.model
    if args.episodes:
        config.num_instances = args.episodes
    if args.max_turns:
        config.max_turns = args.max_turns
    if args.difficulty:
        config.difficulty = args.difficulty
    if args.analyze_traces:
        config.analyze_traces = True
    if args.evaluate_traces:
        config.evaluate_traces = True
    if args.verbose:
        config.verbose_output = True

    # Configure custom OpenAI endpoint if provided
    if args.openai_base_url:
        config.set_custom_endpoint(args.openai_base_url, args.openai_api_key)

    # Fail fast if model_name still missing
    if not config.model_name:
        raise ValueError(
            "CrafterConfig: 'model_name' must be specified in the TOML config or via --model CLI argument; no fallback default."
        )

    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nInterrupted by user")
    except Exception as e:
        print(f"❌ Error: {e}")
        sys.exit(1)

    # JSON-based trace evaluation has been retired; only a pointer to the
    # DuckDB workflow is printed. getattr() guards both flags so a config
    # object without these attributes does not raise AttributeError here.
    if getattr(config, "analyze_traces", False) or getattr(config, "evaluate_traces", False):
        print("\n" + "=" * 80)
        print("📊 TRACE EVALUATION - DEPRECATED")
        print("=" * 80)
        print("⚠️ JSON trace files are no longer generated.")
        print("All trace data is now stored in DuckDB (crafter_traces.duckdb)")
        print("\nTo analyze traces:")
        print("1. Use the DuckDB analytics summary shown above")
        print("2. Query the database directly using DuckDBTraceManager")
        print("3. Use filter_traces_sft_duckdb.py to extract training data")

    # Show DuckDB analytics summary if available
    trace_db_path = "synth_ai/traces/crafter_traces.duckdb"
    try:
        from synth_ai.tracing_v2.duckdb.manager import DuckDBTraceManager

        print("\n" + "─" * 80)
        print("DUCKDB TRACE ANALYTICS")
        print("─" * 80)

        with DuckDBTraceManager(trace_db_path) as db:
            # Get model usage stats
            model_stats = db.get_model_usage()
            if not model_stats.empty:
                print("\n📊 Model Usage:")
                for _, row in model_stats.iterrows():
                    print(f" • {row['model_name']}: {row['call_count']} calls, "
                          f"{row['total_tokens']} tokens, ${row['total_cost']:.4f}")

            # Get session summary
            sessions = db.get_session_summary()
            if not sessions.empty:
                print(f"\n📈 Sessions: {len(sessions)} total")
                print(f" • Avg events per session: {sessions['num_events'].mean():.1f}")
                print(f" • Total cost: ${sessions['total_cost'].sum():.4f}")

        print("─" * 80)
        print(f"💾 Trace data stored in: {trace_db_path}")
    except Exception:
        # Deliberate best effort: silently skip if DuckDB is not available or
        # there is no data.
        pass

    # Normal exit (allow cleanup and final output)
    sys.exit(0)
|