synth-ai 0.2.9.dev3__py3-none-any.whl → 0.2.9.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/analyze_semantic_words.sh +17 -0
- examples/common_old/backend.py +21 -0
- examples/crafter_debug_render.py +180 -0
- examples/evals_old/README.md +98 -0
- examples/evals_old/__init__.py +6 -0
- examples/evals_old/compare_models.py +1037 -0
- examples/evals_old/example_log.md +145 -0
- examples/evals_old/run_demo.sh +126 -0
- examples/evals_old/trace_analysis.py +270 -0
- examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
- examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
- examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
- examples/finetuning_old/synth_qwen_v1/README.md +68 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
- examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
- examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
- examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
- examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
- examples/finetuning_old/synth_qwen_v1/util.py +147 -0
- examples/rl/README.md +169 -0
- examples/rl/configs/eval_base_qwen.toml +15 -0
- examples/rl/configs/eval_rl_qwen.toml +11 -0
- examples/rl/configs/rl_from_base_qwen.toml +35 -0
- examples/rl/configs/rl_from_base_qwen17.toml +74 -0
- examples/rl/configs/rl_from_ft_qwen.toml +35 -0
- examples/rl/download_dataset.py +64 -0
- examples/rl/run_eval.py +435 -0
- examples/rl/run_rl_and_save.py +94 -0
- examples/rl/task_app/README.md +22 -0
- {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
- examples/rl/task_app/math_task_app.py +107 -0
- examples/rl_old/task_app.py +962 -0
- examples/run_crafter_demo.sh +10 -0
- examples/warming_up_to_rl/analyze_trace_db.py +420 -0
- examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
- examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
- examples/warming_up_to_rl/export_trace_sft.py +541 -0
- examples/warming_up_to_rl/groq_test.py +88 -0
- examples/warming_up_to_rl/manage_secrets.py +127 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/readme.md +172 -0
- examples/warming_up_to_rl/run_eval.py +434 -0
- examples/warming_up_to_rl/run_fft_and_save.py +309 -0
- examples/warming_up_to_rl/run_local_rollout.py +188 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
- examples/warming_up_to_rl/run_rl_and_save.py +101 -0
- examples/warming_up_to_rl/run_rollout_remote.py +129 -0
- examples/warming_up_to_rl/task_app/README.md +38 -0
- {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
- synth_ai/api/train/config_finder.py +18 -18
- synth_ai/api/train/env_resolver.py +28 -1
- synth_ai/cli/task_apps.py +291 -56
- synth_ai/task/apps/__init__.py +54 -13
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/RECORD +106 -13
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/top_level.txt +1 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1037 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# ruff: noqa: E402
|
|
3
|
+
"""
|
|
4
|
+
Comprehensive script to run Crafter rollouts for multiple models and compare their performance.
|
|
5
|
+
Updated to use tracing_v3 with async architecture.
|
|
6
|
+
|
|
7
|
+
Runs experiments for:
|
|
8
|
+
- gpt-4o-mini
|
|
9
|
+
- gpt-4.1-mini
|
|
10
|
+
- gpt-4.1-nano
|
|
11
|
+
- gemini-1.5-flash
|
|
12
|
+
- gemini-2.5-flash-lite
|
|
13
|
+
- qwen3/32b
|
|
14
|
+
|
|
15
|
+
Analyzes and compares:
|
|
16
|
+
- Invalid action rates
|
|
17
|
+
- Achievement frequencies by step
|
|
18
|
+
- Achievement counts across models
|
|
19
|
+
- Performance metrics
|
|
20
|
+
- Cost analysis
|
|
21
|
+
"""
|
|
22
|
+
import os
|
|
23
|
+
import sys
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
|
|
26
|
+
# Ensure repository root is on sys.path before importing synth_ai
|
|
27
|
+
synth_ai_root = Path(__file__).parent.parent.parent
|
|
28
|
+
sys.path.insert(0, str(synth_ai_root))
|
|
29
|
+
|
|
30
|
+
# Disable v1 logging to see v3 tracing clearly
|
|
31
|
+
os.environ["LANGFUSE_ENABLED"] = "false"
|
|
32
|
+
os.environ["SYNTH_LOGGING"] = "false"
|
|
33
|
+
|
|
34
|
+
import argparse
|
|
35
|
+
import asyncio
|
|
36
|
+
import contextlib
|
|
37
|
+
import json
|
|
38
|
+
import logging
|
|
39
|
+
import random
|
|
40
|
+
import time
|
|
41
|
+
from collections import defaultdict
|
|
42
|
+
from datetime import datetime
|
|
43
|
+
from typing import Any
|
|
44
|
+
|
|
45
|
+
import httpx
|
|
46
|
+
import numpy as np
|
|
47
|
+
import pandas as pd
|
|
48
|
+
|
|
49
|
+
# Import enhanced LM with v3 tracing
|
|
50
|
+
from synth_ai.lm.core.main_v3 import LM
|
|
51
|
+
from synth_ai.tracing_v3.abstractions import (
|
|
52
|
+
EnvironmentEvent,
|
|
53
|
+
RuntimeEvent,
|
|
54
|
+
SessionEventMarkovBlanketMessage,
|
|
55
|
+
TimeRecord,
|
|
56
|
+
)
|
|
57
|
+
from synth_ai.tracing_v3.db_config import get_default_db_config
|
|
58
|
+
from synth_ai.tracing_v3.decorators import set_turn_number
|
|
59
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
60
|
+
from synth_ai.tracing_v3.turso.manager import AsyncSQLTraceManager
|
|
61
|
+
from tqdm import tqdm
|
|
62
|
+
|
|
63
|
+
# Keep httpx's per-request log lines out of the experiment output.
logging.getLogger("httpx").setLevel(logging.WARNING)

# Crafter achievement hooks for v3 tracing. When the Crafter hook module cannot
# be imported, fall back to an empty HookManager so tracing still works.
try:
    from synth_ai.environments.examples.crafter_classic.trace_hooks_v3 import CRAFTER_HOOKS
except ImportError:
    print("Warning: Could not import CRAFTER_HOOKS for v3")
    from synth_ai.tracing_v3.hooks import HookManager

    CRAFTER_HOOKS = HookManager()
else:
    print(f"✅ Loaded {len(CRAFTER_HOOKS.hooks)} Crafter achievement hooks (Easy, Medium, Hard)")

# Sessions started by this script: session_id -> (experiment_id, tracer)
_SESSIONS: dict[str, tuple[str, object]] = {}

# Models compared by the experiment
MODELS_TO_TEST = [
    "gpt-5-nano",
    "gpt-4.1-nano",
]

# Crafter environment service endpoint (modify based on your setup)
CRAFTER_SERVICE_URL = "http://localhost:8901"

# Trace database location, read from the centralized tracing_v3 config
# (the same config serve.sh uses)
db_config = get_default_db_config()
DATABASE_URL = db_config.database_url

# Retry policy for HTTP requests; delays and the timeout are in seconds
MAX_RETRIES = 3
BASE_DELAY = 0.1
MAX_DELAY = 2.0
HTTP_TIMEOUT = 30.0
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class ExperimentConfig:
    """Settings shared by every model run of the multi-model comparison."""

    def __init__(self):
        # --- Workload -------------------------------------------------------
        self.num_episodes = 10  # episodes run for each model
        self.max_turns = 100  # cap on agent turns within one episode
        self.difficulty = "easy"
        self.base_seed = 1000  # per-episode seeds are offsets from this value
        self.concurrency = 5  # most episodes of one model running at once

        # --- Time limits (seconds) ------------------------------------------
        self.turn_timeout = 30.0  # budget for a single turn
        self.episode_timeout = 300.0  # budget for a whole episode

        # --- Output ---------------------------------------------------------
        self.verbose = True
        self.quiet = False  # verbose output is the default

        # --- Tracing --------------------------------------------------------
        self.save_traces = True
        self.enable_v3_tracing = True
        self.v3_trace_dir = "./traces/v3/crafter_comparison"

        # --- Endpoints ------------------------------------------------------
        self.crafter_service_url = CRAFTER_SERVICE_URL
        self.database_url = DATABASE_URL
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
async def retry_http_request(client: httpx.AsyncClient, method: str, url: str, **kwargs) -> Any:
    """Send an HTTP request, retrying failures with exponential backoff and jitter.

    Up to ``MAX_RETRIES`` attempts are made. Before every attempt after the
    first, the coroutine sleeps ``min(BASE_DELAY * 2**(attempt - 1), MAX_DELAY)``
    seconds plus up to 10% random jitter. Connection and read errors add a
    further, longer pause before the next attempt.

    Args:
        client: Async HTTP client used to issue the request.
        method: HTTP method name, e.g. ``"GET"`` or ``"POST"``.
        url: Target URL.
        **kwargs: Extra keyword arguments forwarded to ``client.request``. A
            ``timeout`` supplied here takes precedence over ``HTTP_TIMEOUT``.

    Returns:
        The first response whose status code is below 500 (4xx responses are
        returned to the caller, not retried).

    Raises:
        Exception: The last failure observed once all attempts are exhausted:
            a generic ``Exception`` describing a 5xx response or a connection
            failure, or the original exception for any other error.
    """
    # Use the module default only when the caller did not pick a timeout;
    # passing both used to fail with a duplicate-keyword TypeError.
    kwargs.setdefault("timeout", HTTP_TIMEOUT)
    last_exception: Exception | None = None

    for attempt in range(MAX_RETRIES):
        try:
            if attempt > 0:
                delay = min(BASE_DELAY * (2 ** (attempt - 1)), MAX_DELAY)
                jitter = random.uniform(0, 0.1 * delay)
                await asyncio.sleep(delay + jitter)

            response = await client.request(method, url, **kwargs)

            if response.status_code < 500:
                return response

            # Server-side failure: remember it and try again.
            last_exception = Exception(f"HTTP {response.status_code}: {response.text}")

        except httpx.ConnectError as e:
            last_exception = Exception(f"Connection failed to {url}: {e}")
            # Keep the original error reachable from the wrapper's traceback.
            last_exception.__cause__ = e
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(1.0 * (2**attempt))
        except httpx.ReadError as e:
            last_exception = e
            if attempt < MAX_RETRIES - 1:
                read_error_delay = min(1.0 * (2**attempt), 5.0)
                await asyncio.sleep(read_error_delay)
        except Exception as e:
            last_exception = e

    if last_exception is None:
        # Only reachable when MAX_RETRIES <= 0, i.e. no attempt was made.
        last_exception = RuntimeError(f"No attempt was made for {method} {url}")

    print(f" ❌ HTTP request failed after {MAX_RETRIES} attempts: {method} {url}")
    print(f" ❌ Error: {type(last_exception).__name__}: {str(last_exception)[:200]}")
    raise last_exception
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# Crafter action names listed in action-ID order: an entry's index is its ID.
_CRAFTER_ACTION_NAMES = (
    "noop",
    "move_left",
    "move_right",
    "move_up",
    "move_down",
    "do",
    "sleep",
    "place_stone",
    "place_table",
    "place_furnace",
    "place_plant",
    "make_wood_pickaxe",
    "make_stone_pickaxe",
    "make_iron_pickaxe",
    "make_wood_sword",
    "make_stone_sword",
    "make_iron_sword",
    "eat_cow",
    "eat_plant",
)

# Action name -> integer action ID
CRAFTER_ACTIONS = {name: action_id for action_id, name in enumerate(_CRAFTER_ACTION_NAMES)}

# Integer action ID -> action name (reverse mapping, used for validation)
INT_TO_ACTION_STRING = dict(enumerate(_CRAFTER_ACTION_NAMES))
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def compress_observation_for_trace(obs: dict[str, Any]) -> str:
    """Summarize an observation as a compact JSON string for trace storage.

    The summary keeps the inventory entries with a positive count (``inv``),
    the ``nearby`` list, health (``hp``) and food (``food``) from ``status``,
    and the number of truthy entries in ``achievements_status`` (``ach``).
    Missing fields fall back to empty or zero values.

    Args:
        obs: Observation dictionary.

    Returns:
        A JSON object string. If the observation cannot be summarized, the
        result is ``{"error": "<message>"}`` instead of an exception.
    """
    try:
        return json.dumps(
            {
                "inv": {k: v for k, v in obs.get("inventory", {}).items() if v > 0},
                "nearby": obs.get("nearby", []),
                "hp": obs.get("status", {}).get("health", 0),
                "food": obs.get("status", {}).get("food", 0),
                "ach": sum(1 for v in obs.get("achievements_status", {}).values() if v),
            },
            separators=(",", ":"),
        )
    except Exception as e:
        # Serialize the message instead of interpolating it so that quotes and
        # backslashes are escaped and the fallback is always valid JSON.
        return json.dumps({"error": str(e)}, ensure_ascii=False)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def create_message(
    content: str, message_type: str, system_id: str, turn: int
) -> SessionEventMarkovBlanketMessage:
    """Build a SessionEventMarkovBlanketMessage for cross-boundary communication.

    The message metadata records ``system_id`` and ``turn``. Its time record
    uses the current wall-clock time as the event time and ``turn`` as the
    message time.
    """
    stamp = TimeRecord(event_time=time.time(), message_time=turn)
    origin = {"system_id": system_id, "turn": turn}
    return SessionEventMarkovBlanketMessage(
        content=content,
        message_type=message_type,
        metadata=origin,
        time_record=stamp,
    )
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# Semantic-map cell codes and the glyph shown for each in the agent prompt:
# empty/grass, tree, stone, cow, water.
_MAP_GLYPHS = ((0, "."), (1, "T"), (2, "S"), (3, "C"), (4, "W"))


def _render_local_map(semantic_map: Any, position: Any) -> str:
    """Render the 5x5 area around the player as prompt text ("" when no map)."""
    if semantic_map is None:
        return ""
    try:
        px, py = position
        half = 5 // 2
        map_lines = []
        for dy in range(-half, half + 1):
            row = []
            for dx in range(-half, half + 1):
                x, y = px + dx, py + dy
                if dx == 0 and dy == 0:
                    row.append("@")  # Player
                elif 0 <= x < len(semantic_map) and 0 <= y < len(semantic_map[0]):
                    cell = semantic_map[x][y]
                    row.append(next((glyph for code, glyph in _MAP_GLYPHS if cell == code), "?"))
                else:
                    row.append("#")  # Out of bounds
            map_lines.append(" ".join(row))
        return "\nMap (5x5 view, @ = you):\n" + "\n".join(map_lines)
    except Exception:
        # The map is a prompt nicety; a malformed map must not end the episode.
        return "\nMap view unavailable"


def _build_turn_prompt(obs: dict[str, Any], turn: int) -> str:
    """Format the current observation as the user prompt for one turn."""
    inventory_str = ", ".join(
        f"{k}: {v}" for k, v in obs.get("inventory", {}).items() if v > 0
    )
    if not inventory_str:
        inventory_str = "empty"

    nearby_str = ", ".join(obs.get("nearby", []))
    if not nearby_str:
        nearby_str = "nothing"

    status = obs.get("status", {})
    health = status.get("health", 0)
    hunger = status.get("food", 0)

    position = obs.get("position", [0, 0])
    achievements = obs.get("achievements_status", {})
    unlocked = [name for name, is_unlocked in achievements.items() if is_unlocked]
    achievements_str = ", ".join(unlocked) if unlocked else "none"

    map_str = _render_local_map(obs.get("semantic_map", None), position)

    return f"""Game State (Turn {turn}):
- Position: {position}
- Health: {health}/9
- Hunger: {hunger}/9
- Inventory: {inventory_str}
- Nearby objects: {nearby_str}
- Achievements unlocked: {achievements_str}
{map_str}

Choose your next actions based on what you see. Use the 'interact' tool with a list of action IDs.

Tips:
- Look at the map! T=tree (wood), S=stone, C=cow (food), W=water
- To collect resources: move to them (actions 1-4) then use action 5 (do)
- To craft: place table (8) first, then craft tools (11-16)
- If hungry and see cow (C), move to it and eat (17)

What actions do you want to take?"""


def _build_system_message() -> str:
    """Build the system prompt that lists every Crafter action ID."""
    action_list = "\n".join(
        f"{action_id}: {action}" for action, action_id in CRAFTER_ACTIONS.items()
    )
    return f"""You are an agent playing Crafter, a 2D survival game. Your goal is to survive and unlock achievements.

You MUST use the 'interact' tool to execute actions. The tool takes a list of action IDs.

Action ID mapping:
{action_list}

Strategy tips:
- Start by collecting wood (move to trees and use action 5)
- Place a crafting table (action 8) to unlock crafting recipes
- Craft tools to collect resources more efficiently
- Eat when hungry, sleep when tired
- Explore to find different resources

IMPORTANT: Always use the 'interact' tool with a list of action IDs. For example: interact(actions=[2, 2, 5]) to move right twice and collect."""


async def _fetch_frame_b64(
    client: httpx.AsyncClient, service_url: str, instance_id: Any
) -> str | None:
    """Fetch the current frame as base64 text, or None when unavailable."""
    try:
        fr = await retry_http_request(
            client,
            "GET",
            f"{service_url}/env/CrafterClassic/frame",
            params={"env_id": instance_id},
        )
        if fr.status_code == 200:
            return fr.json().get("image_base64")
    except Exception:
        # Frames are optional trace decoration; never fail the step over them.
        return None
    return None


def _save_frame(frame_b64: str | None, session_id: str, filename: str) -> str | None:
    """Decode a base64 frame into the session's assets directory; return its path."""
    if not frame_b64:
        return None
    import base64

    assets_dir = Path("traces/v3/assets") / session_id
    assets_dir.mkdir(parents=True, exist_ok=True)
    frame_path = assets_dir / filename
    frame_path.write_bytes(base64.b64decode(frame_b64))
    return str(frame_path)


async def _post_action(
    client: httpx.AsyncClient, service_url: str, instance_id: Any, action_id: Any
) -> Any:
    """Send one action to the Crafter step endpoint with a 5 second overall timeout."""
    return await asyncio.wait_for(
        retry_http_request(
            client,
            "POST",
            f"{service_url}/env/CrafterClassic/step",
            json={
                "env_id": instance_id,
                "action": {
                    "tool_calls": [{"tool": "interact", "args": {"action": action_id}}]
                },
            },
        ),
        timeout=5.0,
    )


async def run_episode(
    config: ExperimentConfig,
    model_name: str,
    episode_num: int,
    experiment_id: str,
    pbar: tqdm | None = None,
) -> dict[str, Any]:
    """Run a single Crafter episode with one model, recording a v3 trace.

    A session tracer is started, the environment is initialized with seed
    ``config.base_seed + episode_num``, and for up to ``config.max_turns``
    turns the model is asked for actions through the ``interact`` tool. Each
    action is sent to the Crafter service separately and recorded as a runtime
    event plus an environment event (with before/after frames when available).
    A response without tool calls is replaced by a noop and counted as invalid.

    Args:
        config: Experiment settings (service URL, timeouts, seeds, ...).
        model_name: Model identifier passed to the LM.
        episode_num: Index of this episode; also the seed offset.
        experiment_id: Identifier stored in the session metadata.
        pbar: Optional progress bar, advanced once per completed turn.

    Returns:
        A dict with ``model``, ``episode``, ``total_achievements``,
        ``achievements``, ``invalid_action_rate``, ``total_actions``,
        ``invalid_actions`` and ``session_id``. If the episode fails, the
        counters are zeroed, ``invalid_action_rate`` is 1.0 and an ``error``
        message is added.
    """
    # Create a new session tracer for this episode
    session_tracer = SessionTracer(hooks=CRAFTER_HOOKS, db_url=config.database_url)

    session_id = await session_tracer.start_session(
        metadata={
            "model": model_name,
            "episode": episode_num,
            "experiment_id": experiment_id,
            "difficulty": config.difficulty,
        }
    )

    # Store session in global bucket
    _SESSIONS[session_id] = (experiment_id, session_tracer)

    lm = LM(
        vendor="openai",
        model=model_name,
        temperature=0.1,  # Low temperature for more consistent gameplay
        session_tracer=session_tracer,
        system_id=f"crafter_agent_{model_name}",
        enable_v3_tracing=True,
    )

    async with httpx.AsyncClient() as client:
        try:
            # Consecutive seeds: base seed + episode number
            seed = config.base_seed + episode_num
            request_data = {"config": {"difficulty": config.difficulty, "seed": seed}}
            init_response = await retry_http_request(
                client,
                "POST",
                f"{config.crafter_service_url}/env/CrafterClassic/initialize",
                json=request_data,
            )
            init_data = init_response.json()

            # The service may report the environment ID under different keys.
            if "instance_id" in init_data:
                instance_id = init_data["instance_id"]
            elif "env_id" in init_data:
                instance_id = init_data["env_id"]
            elif "id" in init_data:
                instance_id = init_data["id"]
            else:
                print(f"❌ Unexpected response format from Crafter service: {init_data}")
                raise KeyError(
                    f"Could not find environment ID in response. Available keys: {list(init_data.keys())}"
                )

            # Initial observation comes from the initialize response
            obs = init_data["observation"]

            done = False
            invalid_actions = 0
            total_actions = 0
            episode_start_time = time.time()

            # The tool definition and system prompt never change between turns,
            # so build them once instead of on every turn.
            from pydantic import BaseModel, Field
            from synth_ai.lm.tools.base import BaseTool

            class InteractArgs(BaseModel):
                actions: list[int] = Field(..., description="List of action IDs to execute")

            interact_tool = BaseTool(
                name="interact",
                arguments=InteractArgs,
                description="Execute actions in the Crafter game",
            )
            system_message = _build_system_message()
            env_system_id = f"crafter_env_{instance_id}"

            for turn in range(config.max_turns):
                if done:
                    break

                # Check episode timeout
                if time.time() - episode_start_time > config.episode_timeout:
                    print(f" ⏰ Episode {episode_num} timed out after {config.episode_timeout}s")
                    done = True
                    break

                set_turn_number(turn)

                # Start timestep for this turn
                await session_tracer.start_timestep(f"turn_{turn}")

                prompt = _build_turn_prompt(obs, turn)

                # Record the observation as a message crossing into the agent
                obs_msg = create_message(
                    f"Observation: {compress_observation_for_trace(obs)}",
                    "system",
                    env_system_id,
                    turn,
                )
                await session_tracer.record_message(
                    content=obs_msg.content,
                    message_type=obs_msg.message_type,
                    event_time=obs_msg.time_record.event_time,
                    message_time=obs_msg.time_record.message_time,
                    metadata=obs_msg.metadata,
                )

                try:
                    # Ask the LM for actions, bounded by the per-turn timeout.
                    # asyncio.TimeoutError is what wait_for raises on every
                    # supported Python version (it is only an alias of the
                    # builtin TimeoutError from 3.11 onwards).
                    try:
                        action_response = await asyncio.wait_for(
                            lm.respond_async(
                                system_message=system_message,
                                user_message=prompt,
                                tools=[interact_tool],
                                turn_number=turn,
                            ),
                            timeout=config.turn_timeout,
                        )
                    except asyncio.TimeoutError:
                        print(
                            f" ⏰ Turn {turn} timed out for episode {episode_num} after {config.turn_timeout}s"
                        )
                        done = True
                        break

                    tool_calls = getattr(action_response, "tool_calls", None)
                    if tool_calls:
                        for tool_call in tool_calls:
                            if done:
                                # A previous call ended the episode or failed;
                                # do not keep acting on a finished environment.
                                break
                            if tool_call.get("function", {}).get("name") != "interact":
                                continue

                            args = json.loads(
                                tool_call.get("function", {}).get("arguments", "{}")
                            )
                            # If no actions provided, use noop
                            actions = args.get("actions", []) or [0]

                            # Execute each action separately
                            for i, action_id in enumerate(actions):
                                frame_before_b64 = await _fetch_frame_b64(
                                    client, config.crafter_service_url, instance_id
                                )
                                total_actions += 1

                                # Validate action ID; unhashable junk (e.g. a
                                # nested list) counts as invalid too.
                                try:
                                    is_valid = action_id in INT_TO_ACTION_STRING
                                except TypeError:
                                    is_valid = False
                                if not is_valid:
                                    action_id = 0
                                    invalid_actions += 1

                                try:
                                    step_response = await _post_action(
                                        client, config.crafter_service_url, instance_id, action_id
                                    )
                                except asyncio.TimeoutError:
                                    print(
                                        f" ⏰ Action execution timed out in episode {episode_num}"
                                    )
                                    done = True
                                    break

                                if step_response.status_code != 200:
                                    print(
                                        f" ❌ Step failed: {step_response.status_code} - {step_response.text}"
                                    )
                                    done = True
                                    break

                                step_data = step_response.json()
                                new_obs = step_data["observation"]
                                reward = step_data["reward"]
                                done = step_data["done"]

                                # Record runtime event for the action; "valid"
                                # reflects this action only, not earlier ones.
                                action_name = INT_TO_ACTION_STRING.get(action_id, "unknown")
                                runtime_event = RuntimeEvent(
                                    system_instance_id=env_system_id,
                                    time_record=TimeRecord(
                                        event_time=time.time(), message_time=turn
                                    ),
                                    actions=[action_id],
                                    metadata={
                                        "action_name": action_name,
                                        "valid": is_valid,
                                    },
                                )
                                await session_tracer.record_event(runtime_event)

                                frame_after_b64 = await _fetch_frame_b64(
                                    client, config.crafter_service_url, instance_id
                                )

                                # Save frames to assets; if either write fails,
                                # record the event without any visuals.
                                try:
                                    before_uri = _save_frame(
                                        frame_before_b64, session_id, f"{turn}_{i}_before.png"
                                    )
                                    after_uri = _save_frame(
                                        frame_after_b64, session_id, f"{turn}_{i}_after.png"
                                    )
                                except Exception:
                                    before_uri = None
                                    after_uri = None

                                # Record environment event with visuals. The
                                # "before" state is the observation the action
                                # was taken from, i.e. the current `obs`.
                                env_event = EnvironmentEvent(
                                    system_instance_id=env_system_id,
                                    time_record=TimeRecord(
                                        event_time=time.time(), message_time=turn
                                    ),
                                    reward=reward,
                                    terminated=done,
                                    system_state_before={
                                        "observation": obs,
                                        **(
                                            {"visuals": {"frame_uri": before_uri}}
                                            if before_uri
                                            else {}
                                        ),
                                    },
                                    system_state_after={
                                        "observation": new_obs,
                                        "public_state": {
                                            "achievements_status": new_obs.get(
                                                "achievements_status", {}
                                            )
                                        },
                                        **(
                                            {"visuals": {"frame_uri": after_uri}}
                                            if after_uri
                                            else {}
                                        ),
                                    },
                                )
                                await session_tracer.record_event(env_event)

                                obs = new_obs

                                if done:
                                    break
                    else:
                        # No tool calls provided: send a noop and count it as invalid
                        total_actions += 1
                        invalid_actions += 1

                        try:
                            step_response = await _post_action(
                                client, config.crafter_service_url, instance_id, 0
                            )
                        except asyncio.TimeoutError:
                            print(f" ⏰ Noop action timed out in episode {episode_num}")
                            done = True
                            break

                        if step_response.status_code != 200:
                            print(
                                f" ❌ Step failed: {step_response.status_code} - {step_response.text}"
                            )
                            done = True
                        else:
                            step_data = step_response.json()
                            new_obs = step_data["observation"]
                            done = step_data["done"]
                            # Only advance the observation after a successful step
                            obs = new_obs

                    # End timestep
                    await session_tracer.end_timestep(f"turn_{turn}")
                    # Update per-episode progress bar once per turn
                    if pbar is not None:
                        current_achievements = sum(
                            1 for v in obs.get("achievements_status", {}).values() if v
                        )
                        pbar.set_postfix({"ach": current_achievements})
                        pbar.update(1)

                except Exception as e:
                    print(f" ❌ Environment step error: {e}")
                    done = True

            # Calculate invalid action rate
            invalid_rate = invalid_actions / total_actions if total_actions > 0 else 0

            # Calculate achievements
            final_achievements = obs.get("achievements_status", {})
            total_achievements = sum(1 for v in final_achievements.values() if v)

            # Terminate environment (best effort)
            try:
                await retry_http_request(
                    client,
                    "POST",
                    f"{config.crafter_service_url}/env/CrafterClassic/terminate",
                    json={"env_id": instance_id},
                )
            except Exception as e:
                print(f" ⚠️ Failed to terminate environment: {e}")

            # End session and close the tracer for this episode
            await session_tracer.end_session(save=config.save_traces)
            await session_tracer.close()

            return {
                "model": model_name,
                "episode": episode_num,
                "total_achievements": total_achievements,
                "achievements": final_achievements,
                "invalid_action_rate": invalid_rate,
                "total_actions": total_actions,
                "invalid_actions": invalid_actions,
                "session_id": session_id,
            }

        except Exception as e:
            print(f" ❌ Episode failed: {e}")
            import traceback

            traceback.print_exc()

            # End session even if failed, then close the tracer
            await session_tracer.end_session(save=config.save_traces)
            await session_tracer.close()

            return {
                "model": model_name,
                "episode": episode_num,
                "total_achievements": 0,
                "achievements": {},
                "invalid_action_rate": 1.0,
                "total_actions": 0,
                "invalid_actions": 0,
                "session_id": session_id,
                "error": str(e),
            }
|
|
753
|
+
|
|
754
|
+
|
|
755
|
+
async def run_model_experiment(
    config: ExperimentConfig, model_name: str, experiment_id: str, position_base: int = 0
) -> list[dict[str, Any]]:
    """Run all episodes for one model concurrently, one stacked progress bar each.

    Args:
        config: Experiment settings. ``num_episodes``, ``max_turns`` and
            ``concurrency`` are read here; the object is also forwarded to
            ``run_episode``.
        model_name: Model identifier, used in the bar labels and forwarded to
            ``run_episode``.
        experiment_id: Forwarded unchanged to ``run_episode``.
        position_base: tqdm ``position`` of this model's first bar; the bar for
            episode ``i`` is placed at ``position_base + i``.

    Returns:
        One result dict per episode, in episode-index order.
    """
    episode_count = config.num_episodes

    # Build every bar up front so each episode owns a fixed row on screen.
    bars = []
    for ep in range(episode_count):
        bars.append(
            tqdm(
                total=config.max_turns,
                desc=f"{model_name} | ep{ep + 1}",
                unit="turn",
                leave=True,
                position=position_base + ep,
            )
        )

    try:
        # Never fewer than one slot, never more than ``config.concurrency`` in flight.
        limiter = asyncio.Semaphore(max(1, int(config.concurrency)))

        async def _run_guarded(ep_idx: int):
            """Wait for a free slot, run one episode, and always close its bar."""
            async with limiter:
                bar = bars[ep_idx]
                try:
                    return await run_episode(config, model_name, ep_idx, experiment_id, bar)
                finally:
                    bar.close()

        # gather() keeps the results in the same order as the episode indices.
        results = await asyncio.gather(*(_run_guarded(ep) for ep in range(episode_count)))

        # Episodes that returned an "error" key are left out of the averages.
        succeeded = [res for res in results if "error" not in res]
        if succeeded and bars:
            ok_count = len(succeeded)
            avg_ach = sum(res["total_achievements"] for res in succeeded) / ok_count
            avg_inv = sum(res["invalid_action_rate"] for res in succeeded) / ok_count
            # NOTE(review): every bar has already been closed by _run_guarded at
            # this point, so this postfix may never be rendered - confirm against
            # tqdm's behavior for closed bars.
            with contextlib.suppress(Exception):
                bars[-1].set_postfix({"avg_ach": f"{avg_ach:.1f}", "inv_rate": f"{avg_inv:.1%}"})

        return results
    finally:
        # Best-effort cleanup; a failure to close one bar must not mask the result.
        for bar in bars:
            with contextlib.suppress(Exception):
                bar.close()
|
|
809
|
+
|
|
810
|
+
|
|
811
|
+
async def analyze_results(config: ExperimentConfig, all_results: dict[str, list[dict[str, Any]]]):
    """Print comparison tables for every model and dump the raw results to JSON.

    Episode results that contain an ``"error"`` key are excluded from all
    statistics. The function prints a per-model performance summary, an
    achievement-frequency table, and the model-usage rows returned by the trace
    database (restricted to the model names present in ``all_results``). It then
    writes everything to ``crafter_experiment_results_<timestamp>.json`` under
    the directory named by ``SYNTH_OUTPUT_DIR`` (default ``temp``).

    Args:
        config: Experiment settings; ``database_url``, ``num_episodes``,
            ``max_turns`` and ``difficulty`` are read.
        all_results: Mapping of model name to that model's per-episode result dicts.
    """
    import os
    from pathlib import Path

    print("\n📊 Analysis Results:")
    print("=" * 80)

    db_manager = AsyncSQLTraceManager(config.database_url)
    await db_manager.initialize()

    try:
        # ---- Per-model aggregate statistics --------------------------------
        model_stats = {}
        for model_name, episodes in all_results.items():
            ok_runs = [ep for ep in episodes if "error" not in ep]
            if not ok_runs:
                # A model with no successful episode gets no stats entry at all.
                continue
            ach_totals = [ep["total_achievements"] for ep in ok_runs]
            inv_rates = [ep["invalid_action_rate"] for ep in ok_runs]
            model_stats[model_name] = {
                "avg_achievements": np.mean(ach_totals),
                "std_achievements": np.std(ach_totals),
                "max_achievements": max(ach_totals),
                "avg_invalid_rate": np.mean(inv_rates),
                "success_rate": len(ok_runs) / len(episodes),
            }

        # ---- Performance summary, best average first -----------------------
        print("\n📈 Model Performance Summary:")
        print(
            f"{'Model':<20} {'Avg Achievements':<18} {'Max Achievements':<18} {'Invalid Rate':<15} {'Success Rate':<15}"
        )
        print("-" * 86)

        ranked = sorted(
            model_stats.items(), key=lambda item: item[1]["avg_achievements"], reverse=True
        )
        for model_name, stats in ranked:
            print(
                f"{model_name:<20} {stats['avg_achievements']:>6.2f} ± {stats['std_achievements']:>4.2f} "
                f"{stats['max_achievements']:>16} {stats['avg_invalid_rate']:>12.2%} {stats['success_rate']:>12.2%}"
            )

        # ---- How often each achievement was unlocked, per model -------------
        print("\n🏆 Achievement Frequencies:")
        unlock_counts = defaultdict(lambda: defaultdict(int))
        for model_name, episodes in all_results.items():
            for episode in episodes:
                if "error" in episode:
                    continue
                for ach_name, unlocked in episode["achievements"].items():
                    if unlocked:
                        unlock_counts[model_name][ach_name] += 1

        # Union of every achievement name unlocked by at least one model.
        seen_achievements = set().union(*unlock_counts.values())

        if seen_achievements:
            model_order = sorted(all_results)
            # Successful-episode count per model is the denominator of each cell.
            ok_totals = {
                name: sum(1 for ep in all_results[name] if "error" not in ep)
                for name in model_order
            }

            print(
                f"\n{'Achievement':<25} " + " ".join(f"{name[:8]:>10}" for name in model_order)
            )
            print("-" * (25 + 11 * len(all_results)))

            for ach_name in sorted(seen_achievements):
                cells = []
                for name in model_order:
                    hits = unlock_counts[name].get(ach_name, 0)
                    total = ok_totals[name]
                    pct = (hits / total * 100) if total > 0 else 0
                    cells.append(f" {hits:>3}/{total:<3} ({pct:>3.0f}%)")
                print(f"{ach_name:<25}" + "".join(cells))

        # ---- Usage rows from the trace database -----------------------------
        print("\n💰 Model Usage Statistics from Current Experiment:")
        model_usage_df = await db_manager.get_model_usage()

        has_usage = model_usage_df is not None and not model_usage_df.empty
        if has_usage:
            # Keep only rows whose model name took part in this experiment.
            in_experiment = model_usage_df["model_name"].isin(set(all_results.keys()))
            usage_rows = model_usage_df[in_experiment]

            if not usage_rows.empty:
                print(
                    f"{'Model':<20} {'Provider':<10} {'Usage Count':<12} {'Avg Latency (ms)':<18} {'Total Cost':<12}"
                )
                print("-" * 72)
                for _, row in usage_rows.iterrows():
                    avg_latency = row["avg_latency_ms"]
                    prefix = f"{row['model_name']:<20} {row['provider'] or 'N/A':<10} {row['usage_count']:<12} "
                    # A missing latency is shown as a padded "N/A" in the same column.
                    if pd.notna(avg_latency):
                        latency_cell = f"{avg_latency:<18.2f}"
                    else:
                        latency_cell = f"{'N/A':<18}"
                    print(f"{prefix}{latency_cell} ${row['total_cost_usd']:<11.4f}")

        # ---- JSON export ----------------------------------------------------
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_dir = Path(os.getenv("SYNTH_OUTPUT_DIR", "temp")).resolve()
        out_dir.mkdir(parents=True, exist_ok=True)
        results_path = out_dir / f"crafter_experiment_results_{timestamp}.json"

        with open(results_path, "w") as out_file:
            payload = {
                "config": {
                    "num_episodes": config.num_episodes,
                    "max_turns": config.max_turns,
                    "difficulty": config.difficulty,
                    "models": list(all_results.keys()),
                },
                "results": all_results,
                "statistics": model_stats,
                "timestamp": timestamp,
            }
            json.dump(payload, out_file, indent=2)

        print(f"\n💾 Detailed results saved to: {results_path}")

    finally:
        # Release the database manager whether or not the analysis succeeded.
        await db_manager.close()
|
|
943
|
+
|
|
944
|
+
|
|
945
|
+
async def main():
    """Parse CLI options, check the Crafter service, run every model, then analyze.

    Returns early, after printing a message, when the service health check
    answers with a non-200 status or raises.
    """
    # (flag, add_argument keyword arguments) in the order they appear in --help.
    option_specs = (
        ("--episodes", {"type": int, "default": 5, "help": "Number of episodes per model"}),
        ("--max-turns", {"type": int, "default": 100, "help": "Maximum turns per episode"}),
        (
            "--difficulty",
            {"choices": ["easy", "medium", "hard"], "default": "easy", "help": "Game difficulty"},
        ),
        ("--models", {"nargs": "+", "default": MODELS_TO_TEST, "help": "Models to test"}),
        ("--no-save", {"action": "store_true", "help": "Don't save traces to database"}),
        ("--quiet", {"action": "store_true", "help": "Reduce output verbosity"}),
        ("--db-url", {"default": DATABASE_URL, "help": "Database URL for tracing"}),
        (
            "--concurrency",
            {"type": int, "default": 5, "help": "Max concurrent rollouts per model"},
        ),
        (
            "--base-seed",
            {
                "type": int,
                "default": 1000,
                "help": "Base seed for episodes (episodes use base_seed+episode_num)",
            },
        ),
        (
            "--turn-timeout",
            {"type": float, "default": 30.0, "help": "Timeout per turn in seconds"},
        ),
        (
            "--episode-timeout",
            {"type": float, "default": 300.0, "help": "Total timeout per episode in seconds"},
        ),
    )
    parser = argparse.ArgumentParser(description="Run Crafter experiments with multiple models")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    # Copy the parsed options onto a fresh configuration object.
    config = ExperimentConfig()
    overrides = {
        "num_episodes": args.episodes,
        "max_turns": args.max_turns,
        "difficulty": args.difficulty,
        "save_traces": not args.no_save,
        "verbose": not args.quiet,
        "quiet": args.quiet,
        "database_url": args.db_url,
        "base_seed": args.base_seed,
        "turn_timeout": args.turn_timeout,
        "episode_timeout": args.episode_timeout,
        # Clamp so a zero or negative value still allows one rollout at a time.
        "concurrency": max(1, int(args.concurrency)),
    }
    for field_name, value in overrides.items():
        setattr(config, field_name, value)

    # The experiment ID carries a second-resolution timestamp.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_id = f"crafter_multi_model_{stamp}"

    rule = "=" * 50
    banner = [
        "🎮 Crafter Multi-Model Experiment",
        rule,
        f"Experiment ID: {experiment_id}",
        f"Models: {', '.join(args.models)}",
        f"Episodes per model: {config.num_episodes}",
        f"Max turns per episode: {config.max_turns}",
        f"Difficulty: {config.difficulty}",
        f"Seeds: {config.base_seed} to {config.base_seed + config.num_episodes - 1}",
        f"Turn timeout: {config.turn_timeout}s",
        f"Episode timeout: {config.episode_timeout}s",
        f"Save traces: {config.save_traces}",
        f"Database URL: {config.database_url}",
        rule,
    ]
    for line in banner:
        print(line)

    # Refuse to start unless the Crafter service answers its health endpoint.
    try:
        async with httpx.AsyncClient() as http:
            health = await http.get(f"{config.crafter_service_url}/health", timeout=5.0)
            if health.status_code != 200:
                print(f"❌ Crafter service not healthy at {config.crafter_service_url}")
                return
    except Exception as exc:
        print(f"❌ Cannot connect to Crafter service at {config.crafter_service_url}: {exc}")
        print("Please ensure the Crafter service is running.")
        return

    print("✅ Crafter service is running")

    # Each model gets num_episodes bar rows plus one spare row before the next model.
    stride = config.num_episodes + 1
    per_model = await asyncio.gather(
        *(
            run_model_experiment(config, model, experiment_id, position_base=slot * stride)
            for slot, model in enumerate(args.models)
        )
    )
    all_results = dict(zip(args.models, per_model, strict=False))

    # Analyze and compare results
    await analyze_results(config, all_results)

    print("\n✅ Experiment complete!")
|
|
1034
|
+
|
|
1035
|
+
|
|
1036
|
+
# Script entry point: when executed directly (not imported), run the async
# main() coroutine to completion with asyncio.run().
if __name__ == "__main__":
    asyncio.run(main())
|