synth-ai 0.2.2.dev0__py3-none-any.whl → 0.2.4.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/cli/__init__.py +66 -0
- synth_ai/cli/balance.py +205 -0
- synth_ai/cli/calc.py +70 -0
- synth_ai/cli/demo.py +74 -0
- synth_ai/{cli.py → cli/legacy_root_backup.py} +60 -15
- synth_ai/cli/man.py +103 -0
- synth_ai/cli/recent.py +126 -0
- synth_ai/cli/root.py +184 -0
- synth_ai/cli/status.py +126 -0
- synth_ai/cli/traces.py +136 -0
- synth_ai/cli/watch.py +508 -0
- synth_ai/config/base_url.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +252 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_duckdb_v2_backup.py +413 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +760 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_synth.py +34 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth.py +1740 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth_v2_backup.py +1318 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_duckdb_v2_backup.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v2_backup.py +1352 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +4 -4
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/test_crafter_react_agent_openai_v2_backup.py +2551 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +1 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/old/traces/session_crafter_episode_16_15227b68-2906-416f-acc4-d6a9b4fa5828_20250725_001154.json +1363 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +3 -3
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/crafter_custom/environment.py +1 -1
- synth_ai/environments/examples/enron/dataset/corbt___enron_emails_sample_questions/default/0.0.0/293c9fe8170037e01cc9cf5834e0cd5ef6f1a6bb/dataset_info.json +1 -0
- synth_ai/environments/examples/nethack/helpers/achievements.json +64 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +1 -1
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +5 -5
- synth_ai/environments/examples/red/units/test_movement_debug.py +2 -2
- synth_ai/environments/examples/red/units/test_retry_movement.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/available_envs.json +122 -0
- synth_ai/environments/examples/sokoban/verified_puzzles.json +54987 -0
- synth_ai/environments/service/core_routes.py +1 -1
- synth_ai/experimental/synth_oss.py +446 -0
- synth_ai/learning/core.py +21 -0
- synth_ai/learning/gateway.py +4 -0
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/mipro.py +8 -0
- synth_ai/lm/__init__.py +3 -0
- synth_ai/lm/core/main.py +4 -0
- synth_ai/lm/core/main_v3.py +238 -122
- synth_ai/lm/core/vendor_clients.py +4 -0
- synth_ai/lm/provider_support/openai.py +11 -2
- synth_ai/lm/vendors/base.py +7 -0
- synth_ai/lm/vendors/openai_standard.py +339 -4
- synth_ai/lm/vendors/openai_standard_responses.py +243 -0
- synth_ai/lm/vendors/synth_client.py +155 -5
- synth_ai/lm/warmup.py +54 -17
- synth_ai/tracing/__init__.py +18 -0
- synth_ai/tracing_v1/__init__.py +29 -14
- synth_ai/tracing_v3/__init__.py +2 -2
- synth_ai/tracing_v3/abstractions.py +62 -17
- synth_ai/tracing_v3/config.py +13 -7
- synth_ai/tracing_v3/db_config.py +6 -6
- synth_ai/tracing_v3/hooks.py +1 -1
- synth_ai/tracing_v3/llm_call_record_helpers.py +350 -0
- synth_ai/tracing_v3/lm_call_record_abstractions.py +257 -0
- synth_ai/tracing_v3/session_tracer.py +5 -5
- synth_ai/tracing_v3/tests/test_concurrent_operations.py +1 -1
- synth_ai/tracing_v3/tests/test_llm_call_records.py +672 -0
- synth_ai/tracing_v3/tests/test_session_tracer.py +43 -9
- synth_ai/tracing_v3/tests/test_turso_manager.py +1 -1
- synth_ai/tracing_v3/turso/manager.py +18 -11
- synth_ai/tracing_v3/turso/models.py +1 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/dashboard.py +329 -0
- synth_ai/v0/tracing/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/base_client.py +3 -3
- synth_ai/{tracing → v0/tracing}/client_manager.py +1 -1
- synth_ai/{tracing → v0/tracing}/context.py +1 -1
- synth_ai/{tracing → v0/tracing}/decorators.py +11 -11
- synth_ai/v0/tracing/events/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/events/manage.py +4 -4
- synth_ai/{tracing → v0/tracing}/events/scope.py +6 -6
- synth_ai/{tracing → v0/tracing}/events/store.py +3 -3
- synth_ai/{tracing → v0/tracing}/immediate_client.py +6 -6
- synth_ai/{tracing → v0/tracing}/log_client_base.py +2 -2
- synth_ai/{tracing → v0/tracing}/retry_queue.py +3 -3
- synth_ai/{tracing → v0/tracing}/trackers.py +2 -2
- synth_ai/{tracing → v0/tracing}/upload.py +4 -4
- synth_ai/v0/tracing_v1/__init__.py +16 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/base_client.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/client_manager.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/context.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/decorators.py +11 -11
- synth_ai/v0/tracing_v1/events/__init__.py +0 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/manage.py +4 -4
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/scope.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/store.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/immediate_client.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/log_client_base.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/retry_queue.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/trackers.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/upload.py +4 -4
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/METADATA +100 -5
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/RECORD +115 -75
- /synth_ai/{tracing/events/__init__.py → compound/cais.py} +0 -0
- /synth_ai/{tracing_v1/events/__init__.py → environments/examples/crafter_classic/debug_translation.py} +0 -0
- /synth_ai/{tracing → v0/tracing}/abstractions.py +0 -0
- /synth_ai/{tracing → v0/tracing}/config.py +0 -0
- /synth_ai/{tracing → v0/tracing}/local.py +0 -0
- /synth_ai/{tracing → v0/tracing}/utils.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/abstractions.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/config.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/local.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/utils.py +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1318 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Test script to run ReAct agents against Crafter environment using LM class with Synth backend.
|
4
|
+
This demonstrates using the LM class with Synth models through native integration.
|
5
|
+
|
6
|
+
This version properly handles the provider routing to use Synth/Modal endpoints.
|
7
|
+
"""
|
8
|
+
|
9
|
+
import asyncio
|
10
|
+
import json
|
11
|
+
import uuid
|
12
|
+
import math
|
13
|
+
import argparse
|
14
|
+
import toml
|
15
|
+
import logging
|
16
|
+
import time
|
17
|
+
import functools
|
18
|
+
from datetime import datetime
|
19
|
+
from typing import Dict, Any, Optional, List, Set, Literal
|
20
|
+
from pydantic import BaseModel, Field
|
21
|
+
from httpx import AsyncClient
|
22
|
+
import httpx
|
23
|
+
import sys
|
24
|
+
import os
|
25
|
+
from pathlib import Path
|
26
|
+
from tqdm.asyncio import tqdm_asyncio
|
27
|
+
from tqdm import tqdm
|
28
|
+
import random
|
29
|
+
import glob
|
30
|
+
from collections import defaultdict
|
31
|
+
|
32
|
+
# Configure logging to suppress noisy third-party logs when in quiet mode
|
33
|
+
def setup_logging(quiet_mode: bool = False):
    """Set the log level of the noisy third-party and tracing loggers.

    Args:
        quiet_mode: When True every logger listed below is raised to
            ``ERROR``; when False the same loggers are set to ``INFO``.
    """
    # A single list drives both modes so that switching from quiet to normal
    # undoes every level the quiet branch set. This module calls
    # setup_logging(quiet_mode=True) at import time, so a partial reset would
    # leave the remaining loggers silenced for the rest of the process.
    logger_names = (
        "httpx",
        "httpcore",  # used by httpx
        "duckdb",
        "synth_ai.tracing_v2",
        "synth_ai.tracing_v2.duckdb",
        "synth_ai.tracing_v2.duckdb.manager",
        "synth_ai.tracing_v2.session_tracer",
    )
    level = logging.ERROR if quiet_mode else logging.INFO
    for logger_name in logger_names:
        logging.getLogger(logger_name).setLevel(level)
|
51
|
+
|
52
|
+
# Set default logging to avoid noisy logs during import
# (runs at import time, before the synth_ai imports further down).
setup_logging(quiet_mode=True)

# Setup environment
# Put the directory six levels above this file at the front of sys.path.
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent.parent))

# Disable v1 logging to see v2 tracing clearly
# Both variables are set unconditionally, overriding any value inherited from
# the calling environment.
os.environ["LANGFUSE_ENABLED"] = "false"
os.environ["SYNTH_LOGGING"] = "false"
|
61
|
+
|
62
|
+
import numpy as np
|
63
|
+
|
64
|
+
# Import Synth warmup utilities
|
65
|
+
from synth_ai.lm.warmup import warmup_synth_model
|
66
|
+
from synth_ai.lm.config import SynthConfig
|
67
|
+
|
68
|
+
# Import session tracer for v2 tracing
|
69
|
+
from synth_ai.tracing_v2.session_tracer import (
|
70
|
+
SessionTracer, SessionEventMarkovBlanketMessage, TimeRecord,
|
71
|
+
RuntimeEvent, EnvironmentEvent, LMCAISEvent
|
72
|
+
)
|
73
|
+
from synth_ai.tracing_v2.utils import create_experiment_context
|
74
|
+
from synth_ai.tracing_v2.duckdb.manager import DuckDBTraceManager
|
75
|
+
from synth_ai.tracing_v2.decorators import (
|
76
|
+
set_active_session_tracer, set_system_id, set_turn_number, get_config
|
77
|
+
)
|
78
|
+
from synth_ai.environments.examples.crafter_classic.trace_hooks import CRAFTER_HOOKS
|
79
|
+
from datetime import datetime
|
80
|
+
|
81
|
+
# Import LM components
|
82
|
+
from synth_ai.lm.core.main_v2 import LM
|
83
|
+
|
84
|
+
# Import Crafter hooks
# Falls back to an empty hook list when the trace_hooks module cannot be
# imported.
# NOTE(review): CRAFTER_HOOKS is also imported unconditionally from the same
# module earlier in this file, so an ImportError would be raised there first
# and this fallback would never run — confirm which of the two imports is
# intended.
try:
    from synth_ai.environments.examples.crafter_classic.trace_hooks import CRAFTER_HOOKS
    print(f"✅ Loaded {len(CRAFTER_HOOKS)} Crafter achievement hooks (Easy, Medium, Hard)")
except ImportError:
    print("Warning: Could not import CRAFTER_HOOKS")
    CRAFTER_HOOKS = []
|
91
|
+
|
92
|
+
# Configuration constants
# Per-request timeout (seconds) passed to the HTTP client by retry_http_request.
HTTP_TIMEOUT = 30.0  # Increased from 10.0 for better handling of concurrent load and LM response times
# Maximum number of attempts retry_http_request makes before giving up.
MAX_RETRIES = 3
# Base delay (seconds) for the backoff between attempts.
RETRY_DELAY = 1.0
|
96
|
+
|
97
|
+
|
98
|
+
def cleanup_old_files(keep_latest: int = 5):
    """Delete stale result and trace files relative to the working directory.

    Keeps the ``keep_latest`` most recently modified
    ``crafter_lm_synth_results_*.json`` files and removes every
    ``session_*.json`` file inside the ``traces_v2_synth`` and
    ``traces_v2_lm_synth`` directories. Deletion is best-effort: a file that
    cannot be removed is reported and skipped, never raised.

    Args:
        keep_latest: Number of newest result files to keep (default 5).
            Negative values are treated as 0.
    """
    def _mtime(path):
        # A file can vanish between the glob and the stat (e.g. another run
        # cleaning up at the same time); sort such files as the oldest.
        try:
            return os.path.getmtime(path)
        except OSError:
            return float("-inf")

    def _remove(path, label):
        try:
            os.remove(path)
        except OSError as exc:
            # Still best-effort, but no longer silent.
            print(f"⚠️ Could not remove {label} {path}: {exc}")
        else:
            print(f"🗑️ Cleaned up old {label}: {path}")

    # Remove old JSON result files (keep only the newest `keep_latest`)
    result_files = sorted(
        glob.glob("crafter_lm_synth_results_*.json"), key=_mtime, reverse=True
    )
    for old_file in result_files[max(keep_latest, 0):]:
        _remove(old_file, "result file")

    # Remove old JSON trace files (keep only DuckDB)
    for trace_dir in ("traces_v2_synth", "traces_v2_lm_synth"):
        if not os.path.exists(trace_dir):
            continue
        for json_file in glob.glob(f"{trace_dir}/session_*.json"):
            _remove(json_file, "trace file")
|
123
|
+
|
124
|
+
|
125
|
+
def setup_synth_environment():
    """Point the OpenAI client environment variables at the Synth/Modal endpoint.

    Reads the base URL from ``SYNTH_BASE_URL`` (falling back to
    ``MODAL_BASE_URL``) and the key from ``SYNTH_API_KEY`` (falling back to
    ``MODAL_API_KEY``), normalizes the URL so it ends in exactly one ``/v1``
    with no trailing slash, and exports the values as ``OPENAI_API_BASE``,
    ``OPENAI_BASE_URL`` and ``OPENAI_API_KEY``.

    Returns:
        Tuple ``(base_url, api_key)`` containing the normalized base URL.

    Raises:
        ValueError: If the base URL or the API key is not set.
    """
    synth_base_url = os.getenv('SYNTH_BASE_URL') or os.getenv('MODAL_BASE_URL')
    synth_api_key = os.getenv('SYNTH_API_KEY') or os.getenv('MODAL_API_KEY')

    if not synth_base_url or not synth_api_key:
        raise ValueError("SYNTH_BASE_URL/MODAL_BASE_URL and SYNTH_API_KEY/MODAL_API_KEY must be set")

    # OpenAI client needs base URL WITH /v1 (it doesn't add it automatically).
    # Strip trailing slashes BEFORE the suffix check: otherwise a URL given as
    # ".../v1/" fails endswith('/v1') and ends up as ".../v1/v1".
    synth_base_url = synth_base_url.rstrip('/')
    if not synth_base_url.endswith('/v1'):
        synth_base_url += '/v1'

    # Set environment variables for OpenAI client to use Synth endpoints
    os.environ["OPENAI_API_BASE"] = synth_base_url
    os.environ["OPENAI_BASE_URL"] = synth_base_url
    os.environ["OPENAI_API_KEY"] = synth_api_key

    return synth_base_url, synth_api_key
|
145
|
+
|
146
|
+
|
147
|
+
async def retry_http_request(client: AsyncClient, method: str, url: str, **kwargs) -> Any:
    """Send an HTTP request, retrying on 5xx responses and on exceptions.

    Makes up to ``MAX_RETRIES`` attempts. Before every attempt after the first
    it sleeps for a capped exponential delay derived from ``RETRY_DELAY`` plus
    up to 10% random jitter. A response with status below 500 is returned
    immediately (4xx responses are NOT retried).

    Args:
        client: The ``httpx.AsyncClient`` used to send the request.
        method: HTTP method name passed to ``client.request``.
        url: Target URL.
        **kwargs: Extra keyword arguments forwarded to ``client.request``.
            ``timeout`` defaults to ``HTTP_TIMEOUT`` when not supplied.

    Returns:
        The first response whose status code is below 500.

    Raises:
        Exception: The last error seen once every attempt has failed (the
            caught exception, or a generic ``Exception`` describing the last
            5xx response).
    """
    # Let a caller-supplied timeout win. Passing timeout= both explicitly and
    # via **kwargs would raise TypeError on every single attempt.
    kwargs.setdefault("timeout", HTTP_TIMEOUT)
    last_exception = None

    for attempt in range(MAX_RETRIES):
        try:
            if attempt > 0:
                delay = min(RETRY_DELAY * (2 ** (attempt - 1)), RETRY_DELAY * 2)  # Use RETRY_DELAY
                jitter = random.uniform(0, 0.1 * delay)
                await asyncio.sleep(delay + jitter)

            response = await client.request(method, url, **kwargs)

            if response.status_code < 500:
                return response

            last_exception = Exception(f"HTTP {response.status_code}: {response.text}")

        except httpx.ReadError as e:
            last_exception = e
            if attempt < MAX_RETRIES - 1:
                # Extra pause after a read error, on top of the regular
                # backoff applied at the start of the next attempt.
                read_error_delay = min(1.0 * (2 ** attempt), 5.0)
                await asyncio.sleep(read_error_delay)
        except Exception as e:
            last_exception = e

    if last_exception is None:
        # Only reachable when MAX_RETRIES <= 0; `raise None` would be a TypeError.
        last_exception = RuntimeError("HTTP request was not attempted (MAX_RETRIES <= 0)")

    print(f" ❌ HTTP request failed after {MAX_RETRIES} attempts: {type(last_exception).__name__}: {str(last_exception)[:200]}")
    raise last_exception
|
176
|
+
|
177
|
+
|
178
|
+
def create_message(content: Any, message_type: str, origin_system_id: Any, turn: int) -> SessionEventMarkovBlanketMessage:
    """Wrap ``content`` in a tracing message tagged with its origin system.

    The message content is a dict holding the stringified
    ``origin_system_id`` and the original ``content`` under ``"payload"``.
    The time record carries the current local time in ISO format as
    ``event_time`` and ``turn`` as ``message_time``.
    """
    envelope = {
        "origin_system_id": str(origin_system_id),
        "payload": content,
    }
    stamp = TimeRecord(
        event_time=datetime.now().isoformat(),
        message_time=turn,
    )
    return SessionEventMarkovBlanketMessage(
        content=envelope,
        message_type=message_type,
        time_record=stamp,
    )
|
191
|
+
|
192
|
+
|
193
|
+
def compress_observation_for_trace(obs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of ``obs`` without its bulky fields.

    The ``"semantic_map"`` and ``"rgb"`` entries are dropped so stored traces
    stay small; ``obs`` itself is left untouched.
    """
    slim = obs.copy()
    for bulky_field in ("semantic_map", "rgb"):
        slim.pop(bulky_field, None)
    return slim
|
206
|
+
|
207
|
+
|
208
|
+
def format_semantic_map_view_v2(obs: Dict[str, Any], view_size: int = 7) -> str:
    """Render a text view of the semantic map centred on the player.

    Args:
        obs: Observation dict. ``"semantic_map"`` is converted with
            ``np.asarray``; a 1-D map is reshaped to a square. The optional
            ``"player_position"`` gives the two map indices of the player and
            defaults to the map centre when missing or ``None``.
        view_size: Width and height of the rendered window in cells.

    Returns:
        One line per row of space-separated cell names (``'you'`` for the
        player, ``'void'`` for out-of-bounds cells), followed by a legend line
        listing the distinct non-grass names that are visible. Returns
        ``"No semantic map available"`` when the observation has no map.
    """
    semantic_map = obs.get("semantic_map")
    if semantic_map is None:
        return "No semantic map available"

    # Convert to numpy array if needed
    sem_arr = np.asarray(semantic_map)
    if sem_arr.ndim == 1:
        # Assuming square map, reshape
        size = int(np.sqrt(sem_arr.size))
        sem_arr = sem_arr.reshape(size, size)

    # An explicit None (key present but unset) must fall back to the map
    # centre too; dict.get's default only covers a missing key.
    player_pos = obs.get("player_position")
    if player_pos is None:
        player_pos = [sem_arr.shape[0] // 2, sem_arr.shape[1] // 2]
    px, py = int(player_pos[0]), int(player_pos[1])

    half = view_size // 2
    lines = []
    visible_items = set()

    # Map of semantic indices to normal names (not symbols)
    # NOTE(review): this id -> name table is hard-coded here; confirm it
    # matches the ids the environment actually emits.
    name_map = {
        0: 'grass',      # Empty/grass
        1: 'tree',       # Tree
        2: 'stone',      # Stone
        3: 'coal',       # Coal
        4: 'iron',       # Iron
        5: 'table',      # Crafting table
        6: 'furnace',    # Furnace
        7: 'diamond',    # Diamond
        8: 'water',      # Water
        9: 'lava',       # Lava
        10: 'sand',      # Sand
        11: 'zombie',    # Enemy/Zombie
        12: 'skeleton',  # Skeleton
        13: 'cow',       # Cow
        14: 'unknown',   # Unknown/Other
    }

    for dy in range(-half, half + 1):
        row = []
        for dx in range(-half, half + 1):
            x, y = px + dx, py + dy

            if dx == 0 and dy == 0:
                row.append('you')  # Player
            elif 0 <= x < sem_arr.shape[0] and 0 <= y < sem_arr.shape[1]:
                item_name = name_map.get(int(sem_arr[x, y]), 'unknown')
                row.append(item_name)
                if item_name != 'grass':
                    visible_items.add(item_name)
            else:
                row.append('void')  # Out of bounds

        lines.append(' '.join(row))

    # Add legend of visible items
    legend = f"Visible items: {', '.join(sorted(visible_items))}" if visible_items else "No special items visible (mostly grass)"

    return "\n".join(lines) + "\n" + legend
|
272
|
+
|
273
|
+
|
274
|
+
def get_openai_tools():
    """Build the OpenAI-style function-tool schemas offered to the model.

    Returns:
        A list with two tool definitions: ``interact`` (a list of action
        names plus a reasoning string) and ``terminate`` (a reason string).
    """
    def _function_tool(name, description, properties, required):
        # Every tool shares the same envelope; only these four parts vary.
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    interact_properties = {
        "actions": {
            "type": "array",
            "items": {"type": "string"},
            "description": "List of actions to perform in sequence (e.g., ['move_right', 'move_right', 'do']). Available actions: move_left, move_right, move_up, move_down, do, sleep, place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword, noop",
        },
        "reasoning": {
            "type": "string",
            "description": "Reasoning for these actions",
        },
    }
    terminate_properties = {
        "reason": {
            "type": "string",
            "description": "Reason for termination",
        },
    }

    interact_tool = _function_tool(
        "interact",
        "Perform actions in the Crafter environment.",
        interact_properties,
        ["actions", "reasoning"],
    )
    terminate_tool = _function_tool(
        "terminate",
        "End the episode when finished or no progress can be made.",
        terminate_properties,
        ["reason"],
    )
    return [interact_tool, terminate_tool]
|
319
|
+
|
320
|
+
|
321
|
+
# --- Configuration Class ---
|
322
|
+
class CrafterConfig:
    """Settings for a Crafter evaluation run against the Synth backend.

    Every attribute starts from a built-in default; when ``config_path``
    names an existing file, values found in that TOML file override them.
    """

    def __init__(self, config_path: Optional[str] = None):
        # --- evaluation defaults ---
        self.model_name: Optional[str] = None
        self.num_instances = 1
        self.max_turns = 2
        self.difficulty = "easy"
        self.service_base_url = "http://localhost:8901"
        self.service_timeout = 30.0
        self.seed = 42
        self.save_traces = True
        self.save_detailed_results = True
        self.verbose = False
        self.quiet = False  # quiet mode support
        self.analyze_traces = False

        # --- v2 tracing ---
        self.enable_v2_tracing = True
        self.v2_trace_dir = "./traces_v2_lm_synth"
        self.duckdb_only = True  # DuckDB only, no individual JSON files
        self.auto_cleanup = True  # remove old files automatically

        # --- Synth backend ---
        self.warmup_model = True
        self.warmup_max_attempts = 30
        self.warmup_timeout = 60.0  # seconds
        self.use_synth_backend = True  # marks the Synth backend as active

        # A missing or empty path simply leaves the defaults in place.
        if config_path and os.path.exists(config_path):
            self.load_from_toml(config_path)

    def load_from_toml(self, config_path: str):
        """Override attributes with the values present in a TOML file.

        Keys absent from the file keep the attribute's current value.
        """
        parsed = toml.load(config_path)

        # (TOML section, key inside that section, attribute it overrides)
        overrides = (
            ("eval", "model_name", "model_name"),
            ("eval", "episodes", "num_instances"),
            ("eval", "max_steps", "max_turns"),
            ("eval", "difficulty", "difficulty"),
            ("eval", "seed", "seed"),
            ("service", "base_url", "service_base_url"),
            ("service", "timeout", "service_timeout"),
            ("output", "save_traces", "save_traces"),
            ("output", "save_detailed_results", "save_detailed_results"),
            ("tracing_v2", "enabled", "enable_v2_tracing"),
            ("tracing_v2", "trace_dir", "v2_trace_dir"),
            ("tracing_v2", "duckdb_only", "duckdb_only"),
            ("tracing_v2", "auto_cleanup", "auto_cleanup"),
            ("synth", "warmup_model", "warmup_model"),
            ("synth", "warmup_max_attempts", "warmup_max_attempts"),
            ("synth", "warmup_timeout", "warmup_timeout"),
            ("synth", "use_synth_backend", "use_synth_backend"),
        )
        for section_name, key, attribute in overrides:
            section = parsed.get(section_name, {})
            setattr(self, attribute, section.get(key, getattr(self, attribute)))
|
390
|
+
|
391
|
+
|
392
|
+
# --- Base ReAct Agent using LM with Synth ---
|
393
|
+
class BaseReActAgentWithLMSynth:
|
394
|
+
"""Base ReAct agent using LM class configured for Synth backend."""
|
395
|
+
|
396
|
+
def __init__(self, model_name: str, max_turns: int = 20, verbose: bool = False,
|
397
|
+
tracer: Optional[SessionTracer] = None, episode_id: int = 0, quiet: bool = False):
|
398
|
+
self.model_name = model_name
|
399
|
+
self.max_turns = max_turns
|
400
|
+
self.verbose = verbose
|
401
|
+
self.quiet = quiet
|
402
|
+
self.history = []
|
403
|
+
self.system_name = "base-react-agent-lm-synth"
|
404
|
+
self.tools = get_openai_tools()
|
405
|
+
self.tracer = tracer
|
406
|
+
self.system_id = f"{self.system_name}_{uuid.uuid4()}"
|
407
|
+
self.episode_id = episode_id
|
408
|
+
|
409
|
+
# Setup Synth environment variables
|
410
|
+
setup_synth_environment()
|
411
|
+
|
412
|
+
# Create LM instance with synth provider
|
413
|
+
self.lm = LM(
|
414
|
+
model_name=model_name,
|
415
|
+
formatting_model_name=model_name,
|
416
|
+
temperature=0.7, # Add some randomness to prevent identical responses
|
417
|
+
synth_logging=False, # Disable v1 tracing
|
418
|
+
provider="synth", # Use synth provider
|
419
|
+
session_tracer=tracer,
|
420
|
+
system_id=self.system_id,
|
421
|
+
enable_v2_tracing=True,
|
422
|
+
)
|
423
|
+
|
424
|
+
# Agent state tracking
|
425
|
+
self.agent_state = {
|
426
|
+
"message_history": [],
|
427
|
+
"steps_taken": 0,
|
428
|
+
"steps_remaining": max_turns,
|
429
|
+
"total_tokens_used": 0,
|
430
|
+
"tool_calls_made": 0,
|
431
|
+
"current_turn": 0
|
432
|
+
}
|
433
|
+
|
434
|
+
async def decide(self, obs: str, system_message: str, turn: int) -> Dict[str, Any]:
|
435
|
+
"""Get agent decision based on observation using LM class with Synth backend."""
|
436
|
+
# Update agent state
|
437
|
+
self.agent_state["current_turn"] = turn
|
438
|
+
self.agent_state["steps_taken"] = turn
|
439
|
+
self.agent_state["steps_remaining"] = self.max_turns - turn
|
440
|
+
|
441
|
+
# Create conversation context with unique episode ID to prevent caching
|
442
|
+
context = f"Episode {self.episode_id} - Turn {turn + 1}/{self.max_turns}\n\n{obs}"
|
443
|
+
|
444
|
+
# Build messages in OpenAI format for tools
|
445
|
+
messages = [
|
446
|
+
{"role": "system", "content": system_message},
|
447
|
+
{"role": "user", "content": context}
|
448
|
+
]
|
449
|
+
|
450
|
+
# Add to message history
|
451
|
+
self.agent_state["message_history"].extend(messages)
|
452
|
+
|
453
|
+
# Truncate history if too long
|
454
|
+
max_history_length = 20
|
455
|
+
if len(self.agent_state["message_history"]) > max_history_length:
|
456
|
+
self.agent_state["message_history"] = (
|
457
|
+
[self.agent_state["message_history"][0]] +
|
458
|
+
self.agent_state["message_history"][-(max_history_length-1):]
|
459
|
+
)
|
460
|
+
|
461
|
+
try:
|
462
|
+
llm_start = time.time()
|
463
|
+
|
464
|
+
# Only show LM call logs if verbose enabled
|
465
|
+
# Note: self.verbose is not directly available but could be passed in
|
466
|
+
|
467
|
+
# Print the full prompt on the final turn to debug achievements
|
468
|
+
if turn == self.max_turns - 1:
|
469
|
+
print("\n🔍 FINAL TURN PROMPT:")
|
470
|
+
print("="*80)
|
471
|
+
print(f"System: {system_message[:200]}...")
|
472
|
+
print(f"\nUser message:\n{context}")
|
473
|
+
print("="*80)
|
474
|
+
|
475
|
+
# Call LM with turn number for v2 tracing
|
476
|
+
# The LM class should handle Synth routing internally
|
477
|
+
response = await self.lm.respond_async(
|
478
|
+
messages=messages,
|
479
|
+
turn_number=turn,
|
480
|
+
# Pass tools in the format expected by LM class
|
481
|
+
# This might need adjustment based on LM implementation
|
482
|
+
tools=self.tools
|
483
|
+
)
|
484
|
+
|
485
|
+
llm_end = time.time()
|
486
|
+
|
487
|
+
# Parse the response to extract tool calls
|
488
|
+
# The LM class returns a BaseLMResponse
|
489
|
+
raw_response = response.raw_response
|
490
|
+
|
491
|
+
# Check if response has tool calls
|
492
|
+
if hasattr(response, 'tool_calls') and response.tool_calls:
|
493
|
+
# Parse tool calls from response
|
494
|
+
tool_call = response.tool_calls[0]
|
495
|
+
decision = {
|
496
|
+
"name": tool_call.get("name", "interact"),
|
497
|
+
"parameters": tool_call.get("parameters", {})
|
498
|
+
}
|
499
|
+
else:
|
500
|
+
# Parse from raw response
|
501
|
+
decision = self._parse_tool_response(raw_response)
|
502
|
+
|
503
|
+
# Update agent state
|
504
|
+
self.agent_state["tool_calls_made"] += 1
|
505
|
+
|
506
|
+
# Add assistant response to history
|
507
|
+
assistant_message = {
|
508
|
+
"role": "assistant",
|
509
|
+
"content": raw_response
|
510
|
+
}
|
511
|
+
self.agent_state["message_history"].append(assistant_message)
|
512
|
+
|
513
|
+
if self.verbose:
|
514
|
+
print(f"🤖 LM Response (turn {turn}): {json.dumps(decision, indent=2)}")
|
515
|
+
print(f"📊 Response time: {llm_end - llm_start:.2f}s")
|
516
|
+
|
517
|
+
# Suppress noisy tool call logs - only show minimal info
|
518
|
+
if not self.quiet:
|
519
|
+
print(f"\n🔧 Turn {turn + 1} - Tool Call: {decision['name']}")
|
520
|
+
if decision['name'] == 'interact':
|
521
|
+
print(f" Actions: {decision['parameters'].get('actions', [])}")
|
522
|
+
print(f" Reasoning: {decision['parameters'].get('reasoning', 'No reasoning provided')}")
|
523
|
+
elif decision['name'] == 'terminate':
|
524
|
+
print(f" Reason: {decision['parameters'].get('reason', 'No reason provided')}")
|
525
|
+
|
526
|
+
except Exception as e:
|
527
|
+
print(f"❌ Error in LM decide: {e}")
|
528
|
+
import traceback
|
529
|
+
traceback.print_exc()
|
530
|
+
# Fallback decision
|
531
|
+
decision = {
|
532
|
+
"name": "interact",
|
533
|
+
"parameters": {
|
534
|
+
"actions": ["do"],
|
535
|
+
"reasoning": f"Error occurred: {str(e)}"
|
536
|
+
}
|
537
|
+
}
|
538
|
+
|
539
|
+
return decision
|
540
|
+
|
541
|
+
def _parse_tool_response(self, raw_response: str) -> Dict[str, Any]:
|
542
|
+
"""Parse raw LM response to extract tool calls."""
|
543
|
+
# Try to parse JSON if present
|
544
|
+
try:
|
545
|
+
# Look for JSON in the response
|
546
|
+
import re
|
547
|
+
json_match = re.search(r'\{.*\}', raw_response, re.DOTALL)
|
548
|
+
if json_match:
|
549
|
+
data = json.loads(json_match.group())
|
550
|
+
if "name" in data:
|
551
|
+
return data
|
552
|
+
elif "function" in data:
|
553
|
+
return {
|
554
|
+
"name": data["function"].get("name", "interact"),
|
555
|
+
"parameters": data["function"].get("arguments", {})
|
556
|
+
}
|
557
|
+
except:
|
558
|
+
pass
|
559
|
+
|
560
|
+
# Fallback to text parsing
|
561
|
+
if "terminate" in raw_response.lower():
|
562
|
+
return {
|
563
|
+
"name": "terminate",
|
564
|
+
"parameters": {
|
565
|
+
"reason": "Agent decided to terminate"
|
566
|
+
}
|
567
|
+
}
|
568
|
+
|
569
|
+
# Try to extract actions from the response
|
570
|
+
actions = []
|
571
|
+
action_keywords = [
|
572
|
+
"move_up", "move_down", "move_left", "move_right", "do", "sleep",
|
573
|
+
"place_stone", "place_table", "place_furnace", "place_plant",
|
574
|
+
"make_wood_pickaxe", "make_stone_pickaxe", "make_iron_pickaxe",
|
575
|
+
"make_wood_sword", "make_stone_sword", "make_iron_sword"
|
576
|
+
]
|
577
|
+
|
578
|
+
for keyword in action_keywords:
|
579
|
+
if keyword in raw_response.lower():
|
580
|
+
actions.append(keyword)
|
581
|
+
|
582
|
+
if not actions:
|
583
|
+
actions = ["do"] # Default action
|
584
|
+
|
585
|
+
return {
|
586
|
+
"name": "interact",
|
587
|
+
"parameters": {
|
588
|
+
"actions": actions, # Return as array of actions
|
589
|
+
"reasoning": "Parsed from response"
|
590
|
+
}
|
591
|
+
}
|
592
|
+
|
593
|
+
def get_system_message(self) -> str:
|
594
|
+
"""Return system message for agent. Override in subclasses."""
|
595
|
+
return """You are an AI agent playing Crafter. Use the available tools to interact with the environment.
|
596
|
+
|
597
|
+
CRITICAL RULE: You MUST provide MULTIPLE actions (2-5) in EVERY interact() tool call!
|
598
|
+
|
599
|
+
The 'interact' function accepts a LIST of 1-5 actions. ALWAYS provide 2-5 actions for efficiency.
|
600
|
+
|
601
|
+
GOOD Examples (what you SHOULD do):
|
602
|
+
✓ interact(actions=["move_right", "move_right", "do"], reasoning="Move to tree and collect wood")
|
603
|
+
✓ interact(actions=["move_up", "move_up", "move_right", "do"], reasoning="Navigate to stone and mine it")
|
604
|
+
✓ interact(actions=["place_table", "make_wood_pickaxe", "move_left"], reasoning="Craft and continue exploring")
|
605
|
+
|
606
|
+
BAD Examples (what you should AVOID):
|
607
|
+
✗ interact(actions=["move_right"], reasoning="Move right") - TOO FEW ACTIONS!
|
608
|
+
✗ interact(actions=["do"], reasoning="Collect") - TOO FEW ACTIONS!
|
609
|
+
|
610
|
+
REMEMBER: Single actions waste time. Always plan 2-5 actions ahead and execute them together!"""
|
611
|
+
|
612
|
+
def format_observation(self, obs: Dict[str, Any]) -> str:
    """Turn an observation into the text shown to the agent.

    The base implementation simply stringifies the whole observation;
    subclasses override it to produce a more readable layout.

    Args:
        obs: Observation mapping to render.

    Returns:
        ``str(obs)``.
    """
    rendered = str(obs)
    return rendered
|
615
|
+
|
616
|
+
|
617
|
+
# --- Crafter-specific ReAct Agent ---
|
618
|
+
class CrafterReActAgentWithLMSynth(BaseReActAgentWithLMSynth):
    """Crafter-specific ReAct agent with enhanced prompting for Synth models."""

    def get_system_message(self) -> str:
        """Return the Crafter-specific system prompt.

        Every action name the prompt tells the model to use must appear in
        the "Available actions" list, because the episode runner maps any
        unknown action name to ``noop``. Eating and drinking are therefore
        described in terms of ``do`` rather than as separate actions.

        Returns:
            The prompt text.
        """
        return """You are CrafterAgent playing Crafter survival environment. Your goal is to unlock as many achievements as possible while staying alive.

You will see a semantic map view showing your surroundings. Use this to navigate toward resources.

Key mechanics:
• 'do' action: collect wood from trees, stone from deposits, food from cows/plants
• 'do' does nothing on grass - move to find resources first; next to water it lets you drink
• Craft progression: wood → table → wood_pickaxe → stone → stone_pickaxe → iron tools
• Sleep when energy low to restore and unlock wake_up achievement
• Use semantic map view to navigate toward resources you can see

Available actions: move_left, move_right, move_up, move_down, do, sleep, place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword, noop

KEY ACHIEVEMENTS TO UNLOCK:
Basic Resource Collection (PRIORITY #1):
- collect_wood: Move NEXT TO a tree, then use action="do" to collect wood
- collect_stone: Move NEXT TO stone, then use action="do" (requires wood_pickaxe in inventory)
- collect_coal: Move NEXT TO coal, then use action="do" (requires stone_pickaxe)
- collect_iron: Move NEXT TO iron, then use action="do" (requires stone_pickaxe)
- collect_diamond: Move NEXT TO diamond, then use action="do" (requires iron_pickaxe)

Tool Crafting (enables resource collection):
- make_wood_pickaxe: Use action="make_wood_pickaxe" when you have wood (unlocks ability to mine stone)
- make_stone_pickaxe: Use action="make_stone_pickaxe" when you have wood and stone (unlocks coal/iron mining)
- make_iron_pickaxe: Use action="make_iron_pickaxe" when you have wood, coal, and iron (unlocks diamond mining)

Weapon Crafting (for defense):
- make_wood_sword: Use action="make_wood_sword" when you have wood
- make_stone_sword: Use action="make_stone_sword" when you have wood and stone
- make_iron_sword: Use action="make_iron_sword" when you have wood, coal, and iron

Survival Actions:
- eat_plant: Move NEXT TO a ripe plant, then use action="do" when food < 9
- eat_cow: Move NEXT TO cow, then use action="do" repeatedly to kill and eat it
- collect_drink: Move NEXT TO water, then use action="do" when drink < 9
- sleep: Use action="sleep" when energy < 5 (restores energy to 9)

Building/Placing:
- place_table: Use action="place_table" when you have wood (enables advanced crafting)
- place_furnace: Use action="place_furnace" when you have stone (for smelting)
- place_plant: Use action="place_plant" when you have sapling (grows into tree)
- place_stone: Use action="place_stone" when you have stone (creates barrier)

Combat:
- defeat_zombie: Move NEXT TO zombie, then use action="do" repeatedly to attack
- defeat_skeleton: Move NEXT TO skeleton, then use action="do" repeatedly to attack

CRITICAL: The action="do" is your INTERACTION button! Use it when adjacent to:
- Trees → get wood
- Stone/Coal/Iron/Diamond → mine resources (need appropriate pickaxe)
- Enemies → attack them
- Cows → kill for food

Simple Strategy:
1. Look for resources (trees, stones) in the semantic map
2. Move toward the nearest resource
3. When adjacent to a resource, use action="do" to collect it
4. If you have wood, try action="make_wood_pickaxe"
5. Repeat: find resources, move to them, use "do"

Critical Gameplay Tips:
- You must be ADJACENT (one tile away) to objects to interact with them
- Use "do" when next to: trees (for wood), stone (for stone), coal, iron, diamond
- Use "do" to attack zombies/skeletons when adjacent
- First priority: Find a tree, move next to it, then use "do" to collect wood
- Wood is essential for crafting your first pickaxe
- With wood_pickaxe you can mine stone, with stone_pickaxe you can mine iron, etc.

CRITICAL INSTRUCTION: You MUST ALWAYS provide MULTIPLE actions (2-5) in EVERY interact() tool call!

The 'interact' function accepts a LIST of 1-5 actions. NEVER use single actions - always plan 2-5 actions ahead!

MANDATORY action sequences (ALWAYS use multiple):
✓ interact(actions=["move_right", "move_right", "do"], reasoning="Move to tree and collect wood")
✓ interact(actions=["move_up", "move_up", "move_right", "do"], reasoning="Navigate and collect")
✓ interact(actions=["place_table", "make_wood_pickaxe", "move_left", "move_left"], reasoning="Craft and explore")
✓ interact(actions=["do", "move_right", "do", "move_right", "do"], reasoning="Collect multiple resources")

FORBIDDEN (NEVER do this):
✗ interact(actions=["move_right"], ...) - WRONG! Too few actions!
✗ interact(actions=["do"], ...) - WRONG! Too few actions!

RULE: If you use less than 2 actions, you are playing inefficiently. Always think 2-5 steps ahead!

Key Strategy:
1. Plan a sequence of moves to reach resources
2. Execute multiple moves in one tool call (e.g., ["move_right", "move_right", "move_up"])
3. When adjacent to a resource, use "do" to collect it
4. Chain crafting actions together (e.g., ["place_table", "make_wood_pickaxe"])

Remember:
- Use "do" when ADJACENT to trees (for wood), stones, or other resources
- Collect wood FIRST before trying to craft anything
- Be efficient - use multiple actions per tool call!
- Focus on unlocking achievements by collecting resources and crafting items."""

    def format_observation(self, obs: Dict[str, Any]) -> str:
        """Format a Crafter observation as a prompt with map, status and reminder.

        Args:
            obs: Observation mapping. ``inventory``, the achievement mapping
                (``achievements_status`` or, failing that,
                ``achievements_info``) and the ``health``/``food``/``drink``/
                ``energy`` stats are read with defaults when missing.

        Returns:
            The text block shown to the model for this turn.
        """
        # NOTE(review): the header below says 15x15 while view_size=7 is
        # passed - presumably view_size is a radius (2*7+1 = 15); confirm
        # against format_semantic_map_view_v2.
        semantic_view = format_semantic_map_view_v2(obs, view_size=7)

        inventory = obs.get('inventory', {})
        # The achievement mapping may arrive under either key.
        achievements = obs.get('achievements_status', obs.get('achievements_info', {}))
        health = obs.get('health', 10)
        food = obs.get('food', 10)
        drink = obs.get('drink', 10)
        energy = obs.get('energy', 10)

        # Achievements whose value is truthy count as unlocked.
        unlocked = [k for k, v in achievements.items() if v]
        achieved = len(unlocked)
        total_achievements = len(achievements)
        unlocked_str = ", ".join(unlocked) if unlocked else "none"

        # Only items the agent actually holds are listed.
        inv_items = [f"{item}: {count}" for item, count in inventory.items() if count > 0]
        inv_str = ", ".join(inv_items) if inv_items else "empty"

        # Placeholder line for recent achievements; currently always empty.
        recent_str = ""

        return f"""=== SEMANTIC MAP VIEW (15x15) ===
{semantic_view}

=== STATUS ===
Health: {health}/10 | Food: {food}/10 | Drink: {drink}/10 | Energy: {energy}/10
Inventory: {inv_str}
Achievements: {achieved}/{total_achievements} unlocked
Unlocked: {unlocked_str}
{recent_str}

What do you see in the map? What actions should you take?

REMINDER: You MUST provide 2-5 actions in your interact() tool call. Plan multiple steps ahead!
Example: interact(actions=["move_right", "move_right", "do"], reasoning="Move to tree and collect wood")"""
|
765
|
+
|
766
|
+
|
767
|
+
# Integer ids sent to the CrafterClassic service for each action name.
# Built once instead of on every turn of every episode.
_CRAFTER_ACTION_IDS = {
    "noop": 0,
    "move_left": 1,
    "move_right": 2,
    "move_up": 3,
    "move_down": 4,
    "do": 5,
    "sleep": 6,
    "place_stone": 7,
    "place_table": 8,
    "place_furnace": 9,
    "place_plant": 10,
    "make_wood_pickaxe": 11,
    "make_stone_pickaxe": 12,
    "make_iron_pickaxe": 13,
    "make_wood_sword": 14,
    "make_stone_sword": 15,
    "make_iron_sword": 16,
}


def _summarize_step_for_log(step_data):
    """Copy a step response, replacing bulky observation fields with short placeholders."""
    step_data_clean = {}
    for key, value in step_data.items():
        if key == "observation" and isinstance(value, dict):
            obs_clean = {}
            for obs_key, obs_value in value.items():
                if obs_key == "semantic_map":
                    obs_clean[obs_key] = f"<semantic_map: {getattr(obs_value, 'shape', 'array')}>"
                elif hasattr(obs_value, '__len__') and len(str(obs_value)) > 200:
                    obs_clean[obs_key] = f"<large_array: {type(obs_value).__name__}>"
                else:
                    obs_clean[obs_key] = obs_value
            step_data_clean[key] = obs_clean
        else:
            step_data_clean[key] = value
    return step_data_clean


async def _terminate_env_best_effort(client, instance_id):
    """Ask the service to terminate an environment; report instead of raising on failure."""
    try:
        await retry_http_request(
            client, "POST", "/env/CrafterClassic/terminate",
            json={"env_id": instance_id}
        )
    except Exception as cleanup_error:
        print(f"⚠️ Could not terminate environment {instance_id}: {cleanup_error}")


async def run_episode(
    episode_id: int,
    config: CrafterConfig,
    session_tracer: Optional[SessionTracer] = None,
    progress_bar: Optional[tqdm] = None,
    quiet: bool = False
):
    """Run a single Crafter episode against the environment service.

    The agent is asked for a decision once per turn; each decision carries a
    list of action names which are stepped through the environment one by
    one. Unknown action names are sent as ``noop``.

    Args:
        episode_id: Zero-based episode index; the environment seed is
            ``episode_id + 1``.
        config: Run configuration (model name, turn limit, service URL,
            difficulty, verbosity).
        session_tracer: Optional tracer; messages and events are recorded
            only while it has a current session.
        progress_bar: Optional progress bar updated once per completed turn.
        quiet: Suppress per-action console output.

    Returns:
        On success, a dict with ``episode_id``, ``total_reward``, ``steps``,
        ``termination_reason``, ``duration``, ``step_results`` and
        ``achievements_unlocked``. If anything raises, a dict with
        ``episode_id``, ``error`` and ``duration`` instead.
    """
    episode_start_time = time.time()

    # Create agent
    agent = CrafterReActAgentWithLMSynth(
        model_name=config.model_name,
        max_turns=config.max_turns,
        verbose=config.verbose,
        tracer=session_tracer,
        episode_id=episode_id,
        quiet=quiet
    )

    # Tracks whether an environment exists that the failure path still has to
    # terminate, so a crashed episode does not leave it running on the service.
    instance_id = None
    needs_cleanup = False

    async with AsyncClient(base_url=config.service_base_url) as client:
        try:
            # Simple sequential seeds: 1, 2, 3, ... (start from 1 instead of 0)
            episode_seed = episode_id + 1

            init_response = await retry_http_request(
                client, "POST", "/env/CrafterClassic/initialize",
                json={
                    "config": {
                        "difficulty": config.difficulty,
                        "seed": episode_seed
                    }
                }
            )

            if config.verbose and episode_id == 0 and not quiet:
                print(f"🎲 Episode {episode_id} using seed: {episode_seed}")
            init_data = init_response.json()
            instance_id = init_data["env_id"]
            needs_cleanup = True
            obs = init_data["observation"]

            # Debug: print first observation structure (only for first episode)
            if config.verbose and episode_id == 0 and not quiet:
                print(f"\n🔍 First observation keys: {list(obs.keys())}")
                if 'inventory' in obs:
                    inv = obs['inventory']
                    non_zero = {k: v for k, v in inv.items() if v > 0}
                    print(f"📦 Starting inventory: {non_zero if non_zero else 'Empty'}")
                if 'achievements_status' in obs:
                    print(f"🏆 Achievement keys: {list(obs['achievements_status'].keys())[:5]}...")

            # Start initial timestep and send initial observation as message
            if session_tracer and session_tracer.current_session:
                session_tracer.start_timestep(0)  # Start timestep for turn 0
                obs_msg = create_message(
                    compress_observation_for_trace(obs),
                    "observation",
                    f"crafter_env_{instance_id}",
                    0
                )
                session_tracer.record_message(obs_msg)

            # Run episode
            episode_reward = 0
            termination_reason = None
            step_results = []

            for turn in range(config.max_turns):
                if progress_bar:
                    progress_bar.set_description(f"Episode {episode_id}: Step {turn+1}/{config.max_turns}")
                elif config.verbose and turn % 5 == 0 and not quiet:
                    # Print progress every 5 steps when no progress bar
                    print(f" Episode {episode_id}: Step {turn+1}/{config.max_turns}")

                # Turn 0's timestep was already started above.
                if turn > 0 and session_tracer and session_tracer.current_session:
                    session_tracer.start_timestep(turn)

                # Get agent decision
                obs_formatted = agent.format_observation(obs)
                system_msg = agent.get_system_message()

                decision = await agent.decide(obs_formatted, system_msg, turn)

                # Handle termination
                if decision["name"] == "terminate":
                    termination_reason = decision["parameters"]["reason"]
                    break

                # Execute actions in sequence
                actions = decision["parameters"]["actions"]

                # Must be bound even if the agent returned no actions, because
                # it is checked after the action loop below.
                done = False

                for action in actions:
                    # Unknown action names fall back to noop (0).
                    action_int = _CRAFTER_ACTION_IDS.get(action, 0)

                    # Get state before action
                    state_before = {"observation": obs}
                    prev_obs = obs.copy()

                    # Step environment
                    step_response = await retry_http_request(
                        client, "POST", "/env/CrafterClassic/step",
                        json={
                            "env_id": instance_id,
                            "action": {"tool_calls": [{"tool": "interact", "args": {"action": action_int}}]}
                        }
                    )
                    step_data = step_response.json()

                    if config.verbose and not quiet:
                        print(f"Step response keys: {list(step_data.keys())}")
                        # Log a cleaned version (large arrays replaced by placeholders)
                        print(f"Step response: {_summarize_step_for_log(step_data)}")

                    obs = step_data["observation"]
                    # 0 only covers a missing key; an explicit None from the
                    # service is kept and skipped by the checks below.
                    reward = step_data.get("reward", 0)
                    done = step_data["done"]
                    info = step_data.get("info", {})

                    if reward is not None:
                        episode_reward += reward

                    # Only log action results if not in quiet mode
                    if not quiet and reward is not None:
                        print(f"\n🏆 After action '{action}':")
                        print(f" Reward: {reward}")

                        # Print any achievements unlocked
                        achievements_unlocked = [
                            key
                            for key, value in obs.get('achievements_status', {}).items()
                            if value
                        ]
                        print(f" Achievements unlocked: {achievements_unlocked}")

                        # Print inventory (only non-zero items)
                        inventory = obs.get('inventory', {})
                        non_zero_inventory = {k: v for k, v in inventory.items() if v > 0}
                        print(f" Inventory: {non_zero_inventory}")

                    # Record step result
                    step_results.append({
                        "turn": turn,
                        "action": action,
                        "reward": reward,
                        "done": done,
                        "info": info
                    })

                    # Record environment event for hooks to catch
                    if session_tracer and session_tracer.current_session:
                        # Environment event with the state transition
                        env_event = EnvironmentEvent(
                            time_record=TimeRecord(
                                event_time=datetime.now().isoformat(),
                                message_time=turn
                            ),
                            system_instance_id=f"crafter_env_{instance_id}",
                            system_state_before={"public_state": prev_obs},
                            system_state_after={"public_state": obs},
                            reward=reward,
                            terminated=done,
                            metadata={
                                "action": action,
                                "action_int": action_int,
                                "info": info
                            }
                        )
                        session_tracer.record_event(env_event)

                        # Also record runtime event for invalid action detection
                        runtime_event = RuntimeEvent(
                            time_record=TimeRecord(
                                event_time=datetime.now().isoformat(),
                                message_time=turn
                            ),
                            system_instance_id=f"crafter_runtime_{instance_id}",
                            actions=[action_int],
                            system_state_before=state_before,
                            system_state_after={"observation": obs},
                            metadata={
                                "action_name": action,
                                "action_int": action_int,
                                "reward": reward
                            }
                        )
                        session_tracer.record_event(runtime_event)

                    if done:
                        break

                # After all actions in sequence, send final observation message
                if session_tracer and session_tracer.current_session:
                    obs_msg = create_message(
                        compress_observation_for_trace(obs),
                        "observation",
                        f"crafter_env_{instance_id}",
                        turn + 1
                    )
                    session_tracer.record_message(obs_msg)

                if done:
                    break

                if progress_bar:
                    progress_bar.update(1)

            # Terminate instance. The flag is cleared first so a failure of
            # this call is not followed by a second attempt in the handler.
            needs_cleanup = False
            await retry_http_request(
                client, "POST", "/env/CrafterClassic/terminate",
                json={"env_id": instance_id}
            )

        except Exception as e:
            print(f"❌ Episode {episode_id} failed: {e}")
            import traceback
            traceback.print_exc()
            if needs_cleanup:
                await _terminate_env_best_effort(client, instance_id)
            return {
                "episode_id": episode_id,
                "error": str(e),
                "duration": time.time() - episode_start_time
            }

    # Extract final achievements
    final_achievements = []
    if obs and 'achievements_status' in obs:
        final_achievements = [k for k, v in obs['achievements_status'].items() if v]

    # Return results
    return {
        "episode_id": episode_id,
        "total_reward": episode_reward,
        "steps": len(step_results),
        "termination_reason": termination_reason,
        "duration": time.time() - episode_start_time,
        "step_results": step_results,
        "achievements_unlocked": final_achievements
    }
|
1043
|
+
|
1044
|
+
|
1045
|
+
# --- Main ---
|
1046
|
+
async def main():
    """Parse CLI options, run all episodes concurrently and report the results.

    Steps, in order: load the config and apply command-line overrides, check
    the Crafter service's ``/health`` endpoint (returning early if it cannot
    be reached), optionally warm up the model, optionally set up v2 tracing,
    run ``config.num_instances`` episodes with ``asyncio.gather``, end each
    episode's tracing session, then print a summary and, when detailed
    results are requested without tracing, write them to a JSON file.
    """
    parser = argparse.ArgumentParser(description="Run Crafter evaluation with LM Synth backend")
    parser.add_argument("--config", type=str, help="Path to TOML config file")
    parser.add_argument("--model", type=str, help="Model name (overrides config)")
    parser.add_argument("--episodes", type=int, help="Number of episodes (overrides config)")
    parser.add_argument("--max-steps", type=int, help="Max steps per episode (overrides config)")
    parser.add_argument("--difficulty", type=str, choices=["easy", "normal", "hard"], help="Difficulty override")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument("--quiet", action="store_true", help="Suppress most output except results")
    parser.add_argument("--no-traces", action="store_true", help="Disable trace saving")
    parser.add_argument("--analyze", action="store_true", help="Analyze traces after running")
    parser.add_argument("--skip-warmup", action="store_true", help="Skip model warmup")

    args = parser.parse_args()

    # Load configuration
    config = CrafterConfig(args.config)

    # Setup Synth environment variables
    setup_synth_environment()

    # Clean up old files to keep directory clean
    if config.auto_cleanup:
        cleanup_old_files()

    # Apply command-line overrides
    if args.model:
        config.model_name = args.model
    # Compare against None so an explicit 0 on the command line is honoured
    # rather than silently replaced by the config value.
    if args.episodes is not None:
        config.num_instances = args.episodes
    if args.max_steps is not None:
        config.max_turns = args.max_steps
    if args.difficulty:
        config.difficulty = args.difficulty
    if args.verbose:
        config.verbose = True
    if args.quiet:
        config.quiet = True
        if not args.verbose:  # Don't show this if verbose is also on
            print("🔇 Quiet mode enabled - suppressing verbose logs")
    else:
        config.quiet = False

    # Configure logging based on quiet mode
    setup_logging(quiet_mode=config.quiet)

    if args.no_traces:
        config.save_traces = False
        config.enable_v2_tracing = False
    if args.analyze:
        config.analyze_traces = True
    if args.skip_warmup:
        config.warmup_model = False

    # Ensure model is specified
    if not config.model_name:
        parser.error("Model name must be specified via --model or config file")

    print("🎮 Crafter ReAct Agent Evaluation (LM with Synth Backend)")
    print(f"Model: {config.model_name}")
    print(f"Service: {config.service_base_url}")
    print(f"Instances: {config.num_instances}")
    print(f"Max Turns: {config.max_turns}")
    print(f"Difficulty: {config.difficulty}")
    print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("=" * 50)

    # Test service health
    async with AsyncClient(base_url=config.service_base_url) as client:
        try:
            health_resp = await retry_http_request(client, "GET", "/health")
            health_data = health_resp.json()
            print(f"✅ Crafter service is healthy: {health_data}")
        except Exception as e:
            print(f"❌ Failed to connect to Crafter service: {e}")
            return

    # Warm up the model if requested
    if config.warmup_model and not args.skip_warmup:
        print(f"\n🔥 Warming up {config.model_name} on Synth backend...")
        try:
            synth_base_url = os.getenv('SYNTH_BASE_URL') or os.getenv('MODAL_BASE_URL')
            synth_api_key = os.getenv('SYNTH_API_KEY') or os.getenv('MODAL_API_KEY')
            if synth_base_url and synth_api_key:
                synth_config = SynthConfig(
                    base_url=synth_base_url,
                    api_key=synth_api_key,
                    timeout=config.warmup_timeout  # Use configurable timeout
                )
                await warmup_synth_model(config.model_name, synth_config)
                print("✅ Model warmed up successfully!")
            else:
                print("⚠️ Missing SYNTH_BASE_URL or SYNTH_API_KEY, skipping warmup")
        except Exception as e:
            # Warmup is best-effort; the run proceeds without it.
            print(f"⚠️ Warmup failed: {e}")
            print("Continuing anyway...")

    # Set up v2 tracing if enabled
    trace_manager = None
    experiment_ctx = None

    if config.enable_v2_tracing:
        # Create trace directory first
        os.makedirs(config.v2_trace_dir, exist_ok=True)

        # Initialize trace manager
        trace_manager = DuckDBTraceManager(db_path=f"{config.v2_trace_dir}/traces.duckdb")

        # Create experiment context
        experiment_ctx = create_experiment_context(
            db_manager=trace_manager,
            experiment_name=f"crafter_lm_synth_{config.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            description=f"Crafter LM Synth experiment with {config.model_name} on {config.difficulty} difficulty, using LM class"
        )

        print(f"\n📊 V2 Tracing enabled. Traces will be saved to: {config.v2_trace_dir}")
        print(f" Experiment: {experiment_ctx['experiment_name']}")

    # Run episodes in parallel using asyncio.gather for better multi-container testing
    print(f"\n🚀 Running {config.num_instances} episodes in parallel to test multi-container scaling...")

    episode_seeds = []  # Seed used by each episode, indexed by episode id

    # Prepare episode tasks
    episode_tasks = []
    session_tracers = []

    for i in range(config.num_instances):
        # Mirrors the seed run_episode derives (simple sequential: 1, 2, 3, ...)
        episode_seed = i + 1
        episode_seeds.append(episode_seed)

        # Create session tracer for this episode if v2 tracing is enabled
        session_tracer = None
        if config.enable_v2_tracing and trace_manager:
            session_tracer = SessionTracer(
                traces_dir=config.v2_trace_dir,
                hooks=CRAFTER_HOOKS,
                duckdb_path=f"{config.v2_trace_dir}/traces.duckdb",
                experiment_id=experiment_ctx['experiment_id']
            )

            # Start session with episode metadata
            session_id = f"crafter_episode_{i}_{uuid.uuid4().hex[:8]}"
            session_tracer.start_session(session_id)

        session_tracers.append(session_tracer)

        # Create the episode coroutine (not awaited yet); no progress bar
        # for parallel execution.
        episode_task = run_episode(i, config, session_tracer, None, args.quiet)
        episode_tasks.append(episode_task)

    print(f"📤 Starting {len(episode_tasks)} episodes concurrently...")
    start_time = time.time()

    # Run all episodes in parallel using asyncio.gather
    results = await asyncio.gather(*episode_tasks, return_exceptions=True)

    end_time = time.time()
    parallel_time = end_time - start_time

    # Guard the division: with no episodes (or a coarse clock) the elapsed
    # time can be zero.
    throughput = len(episode_tasks) / parallel_time if parallel_time > 0 else 0.0

    print(f"✅ Completed {len(episode_tasks)} episodes in {parallel_time:.2f} seconds")
    print(f"📊 Parallel execution throughput: {throughput:.2f} episodes/second")

    # Process results and handle any exceptions
    successful_results = []
    failed_results = []

    for i, result in enumerate(results):
        # BaseException rather than Exception: with return_exceptions=True a
        # cancelled episode comes back as CancelledError, which is not an
        # Exception subclass and would otherwise be treated as a result dict.
        if isinstance(result, BaseException):
            print(f"❌ Episode {i} failed: {result}")
            failed_results.append({"episode_id": i, "error": str(result)})
        else:
            successful_results.append(result)

        # End session and save trace
        session_tracer = session_tracers[i]
        if session_tracer:
            # Only save JSON file if not in duckdb_only mode
            save_json = not config.duckdb_only
            session_tracer.end_session(save=save_json)

            # Trace is automatically saved to DuckDB by end_session()
            if config.save_traces and config.verbose:
                if config.duckdb_only:
                    print(f"💾 Saved trace for episode {i} to DuckDB only")
                else:
                    print(f"💾 Saved trace for episode {i} to DuckDB and JSON")

    # run_episode reports its own failures as dicts carrying an 'error' key,
    # so the success/failure split below is made on that key.
    results = successful_results + failed_results

    # Analyze results
    print("\n" + "=" * 50)
    print("📊 EVALUATION RESULTS")
    print("=" * 50)

    successful_episodes = [r for r in results if 'error' not in r]
    failed_episodes = [r for r in results if 'error' in r]

    if successful_episodes:
        total_reward = sum(r['total_reward'] for r in successful_episodes)
        total_steps = sum(r['steps'] for r in successful_episodes)
        avg_reward = total_reward / len(successful_episodes)
        avg_steps = total_steps / len(successful_episodes)

        print(f"Episodes completed: {len(successful_episodes)}/{config.num_instances}")
        print(f"Failed episodes: {len(failed_episodes)}")
        print(f"Total reward: {total_reward:.2f}")
        print(f"Average reward per episode: {avg_reward:.2f}")
        print(f"Total steps: {total_steps}")
        print(f"Average steps per episode: {avg_steps:.2f}")

        # Show seeds used, keyed by each successful episode's own id so the
        # listing stays correct when an earlier episode failed.
        if episode_seeds:
            print("\nSeeds used:")
            for episode in successful_episodes:
                ep_id = episode['episode_id']
                print(f" Episode {ep_id}: seed {episode_seeds[ep_id]}")

        # Extract unique achievements
        all_achievements = set()
        achievement_counts = defaultdict(int)

        for result in successful_episodes:
            # Use the achievements_unlocked field reported by run_episode
            if 'achievements_unlocked' in result:
                for achievement in result['achievements_unlocked']:
                    all_achievements.add(achievement)
                    achievement_counts[achievement] += 1

        print(f"Unique achievements unlocked: {len(all_achievements)}")
        if all_achievements:
            print("\nAchievements unlocked:")
            for achievement, count in sorted(achievement_counts.items()):
                print(f" - {achievement}: {count} episodes ({count/len(successful_episodes)*100:.1f}%)")
    else:
        print("No successful episodes completed.")

    # Save detailed results to DuckDB if tracing is enabled
    if config.save_detailed_results and config.enable_v2_tracing and trace_manager:
        # For now, just print that results are available in DuckDB.
        # The session traces are already saved to DuckDB via the SessionTracer.
        print(f"\n💾 Results available in DuckDB: {trace_manager.db_path}")
        print(f" Experiment ID: {experiment_ctx['experiment_id']}")
        print(" Use DuckDB queries to analyze results")
    elif config.save_detailed_results:
        # Fallback to JSON if no tracing
        results_file = f"crafter_lm_synth_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(results_file, 'w') as f:
            json.dump({
                'config': {
                    'model': config.model_name,
                    'episodes': config.num_instances,
                    'max_steps': config.max_turns,
                    'difficulty': config.difficulty,
                    'backend': 'synth'
                },
                'results': results,
                'summary': {
                    'successful_episodes': len(successful_episodes),
                    'failed_episodes': len(failed_episodes),
                    # The aggregates only exist when at least one episode succeeded.
                    'total_reward': total_reward if successful_episodes else 0,
                    'avg_reward': avg_reward if successful_episodes else 0,
                    'unique_achievements': list(all_achievements) if successful_episodes else []
                }
            }, f, indent=2)
        print(f"\n💾 Detailed results saved to: {results_file}")
|
1315
|
+
|
1316
|
+
|
1317
|
+
# Script entry point: drive the async main() to completion when run directly.
if __name__ == "__main__":
    asyncio.run(main())
|