synth-ai 0.2.2.dev0__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/cli/__init__.py +66 -0
- synth_ai/cli/balance.py +205 -0
- synth_ai/cli/calc.py +70 -0
- synth_ai/cli/demo.py +74 -0
- synth_ai/{cli.py → cli/legacy_root_backup.py} +60 -15
- synth_ai/cli/man.py +103 -0
- synth_ai/cli/recent.py +126 -0
- synth_ai/cli/root.py +184 -0
- synth_ai/cli/status.py +126 -0
- synth_ai/cli/traces.py +136 -0
- synth_ai/cli/watch.py +508 -0
- synth_ai/config/base_url.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +252 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_duckdb_v2_backup.py +413 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +646 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_synth.py +34 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth.py +1740 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth_v2_backup.py +1318 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_duckdb_v2_backup.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v2_backup.py +1352 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/test_crafter_react_agent_openai_v2_backup.py +2551 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/old/traces/session_crafter_episode_16_15227b68-2906-416f-acc4-d6a9b4fa5828_20250725_001154.json +1363 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +3 -3
- synth_ai/environments/examples/enron/dataset/corbt___enron_emails_sample_questions/default/0.0.0/293c9fe8170037e01cc9cf5834e0cd5ef6f1a6bb/dataset_info.json +1 -0
- synth_ai/environments/examples/nethack/helpers/achievements.json +64 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +1 -1
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +5 -5
- synth_ai/environments/examples/red/units/test_movement_debug.py +2 -2
- synth_ai/environments/examples/red/units/test_retry_movement.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/available_envs.json +122 -0
- synth_ai/environments/examples/sokoban/verified_puzzles.json +54987 -0
- synth_ai/experimental/synth_oss.py +446 -0
- synth_ai/learning/core.py +21 -0
- synth_ai/learning/gateway.py +4 -0
- synth_ai/learning/prompts/mipro.py +0 -0
- synth_ai/lm/__init__.py +3 -0
- synth_ai/lm/core/main.py +4 -0
- synth_ai/lm/core/main_v3.py +68 -13
- synth_ai/lm/core/vendor_clients.py +4 -0
- synth_ai/lm/provider_support/openai.py +11 -2
- synth_ai/lm/vendors/base.py +7 -0
- synth_ai/lm/vendors/openai_standard.py +339 -4
- synth_ai/lm/vendors/openai_standard_responses.py +243 -0
- synth_ai/lm/vendors/synth_client.py +155 -5
- synth_ai/lm/warmup.py +54 -17
- synth_ai/tracing/__init__.py +18 -0
- synth_ai/tracing_v1/__init__.py +29 -14
- synth_ai/tracing_v3/config.py +13 -7
- synth_ai/tracing_v3/db_config.py +6 -6
- synth_ai/tracing_v3/turso/manager.py +8 -8
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/dashboard.py +329 -0
- synth_ai/v0/tracing/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/base_client.py +3 -3
- synth_ai/{tracing → v0/tracing}/client_manager.py +1 -1
- synth_ai/{tracing → v0/tracing}/context.py +1 -1
- synth_ai/{tracing → v0/tracing}/decorators.py +11 -11
- synth_ai/v0/tracing/events/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/events/manage.py +4 -4
- synth_ai/{tracing → v0/tracing}/events/scope.py +6 -6
- synth_ai/{tracing → v0/tracing}/events/store.py +3 -3
- synth_ai/{tracing → v0/tracing}/immediate_client.py +6 -6
- synth_ai/{tracing → v0/tracing}/log_client_base.py +2 -2
- synth_ai/{tracing → v0/tracing}/retry_queue.py +3 -3
- synth_ai/{tracing → v0/tracing}/trackers.py +2 -2
- synth_ai/{tracing → v0/tracing}/upload.py +4 -4
- synth_ai/v0/tracing_v1/__init__.py +16 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/base_client.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/client_manager.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/context.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/decorators.py +11 -11
- synth_ai/v0/tracing_v1/events/__init__.py +0 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/manage.py +4 -4
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/scope.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/store.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/immediate_client.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/log_client_base.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/retry_queue.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/trackers.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/upload.py +4 -4
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/METADATA +98 -4
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/RECORD +98 -62
- /synth_ai/{tracing/events/__init__.py → environments/examples/crafter_classic/debug_translation.py} +0 -0
- /synth_ai/{tracing_v1/events/__init__.py → learning/prompts/gepa.py} +0 -0
- /synth_ai/{tracing → v0/tracing}/abstractions.py +0 -0
- /synth_ai/{tracing → v0/tracing}/config.py +0 -0
- /synth_ai/{tracing → v0/tracing}/local.py +0 -0
- /synth_ai/{tracing → v0/tracing}/utils.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/abstractions.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/config.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/local.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/utils.py +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1740 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Test script to run ReAct agents against Crafter environment using LM class with Synth backend.
|
4
|
+
This demonstrates using the LM class with Synth models through native integration.
|
5
|
+
|
6
|
+
This version uses the new tracing_v3 system with async Turso/SQLite backend.
|
7
|
+
"""
|
8
|
+
|
9
|
+
import logging
|
10
|
+
# Disable httpx logging immediately
|
11
|
+
logging.getLogger("httpx").setLevel(logging.ERROR)
|
12
|
+
logging.getLogger("httpcore").setLevel(logging.ERROR)
|
13
|
+
|
14
|
+
import asyncio
|
15
|
+
import json
|
16
|
+
import uuid
|
17
|
+
import math
|
18
|
+
import argparse
|
19
|
+
import toml
|
20
|
+
import yaml
|
21
|
+
import time
|
22
|
+
import functools
|
23
|
+
from datetime import datetime
|
24
|
+
from typing import Dict, Any, Optional, List, Set, Literal
|
25
|
+
from pydantic import BaseModel, Field
|
26
|
+
from httpx import AsyncClient
|
27
|
+
import httpx
|
28
|
+
import sys
|
29
|
+
import os
|
30
|
+
from pathlib import Path
|
31
|
+
from tqdm import tqdm
|
32
|
+
import random
|
33
|
+
import glob
|
34
|
+
from collections import defaultdict
|
35
|
+
|
36
|
+
# Configure logging to suppress noisy third-party logs when in quiet mode
|
37
|
+
def setup_logging(quiet_mode: bool = False):
    """Set log levels for the HTTP, database and tracing loggers.

    Args:
        quiet_mode: When true, the httpx, httpcore, sqlalchemy, aiosqlite and
            synth_ai.tracing_v3 (including its turso child) loggers are all
            limited to ERROR. Otherwise only httpx is limited to ERROR and
            synth_ai.tracing_v3 is set to INFO; other loggers are untouched.
    """
    if quiet_mode:
        levels = {
            "httpx": logging.ERROR,
            "synth_ai.tracing_v3": logging.ERROR,
            "synth_ai.tracing_v3.turso": logging.ERROR,
            "sqlalchemy": logging.ERROR,
            "aiosqlite": logging.ERROR,
            # httpcore is the transport used by httpx
            "httpcore": logging.ERROR,
        }
    else:
        levels = {
            # httpx stays suppressed regardless of mode
            "httpx": logging.ERROR,
            "synth_ai.tracing_v3": logging.INFO,
        }

    for logger_name, level in levels.items():
        logging.getLogger(logger_name).setLevel(level)
|
52
|
+
|
53
|
+
# Set default logging to avoid noisy logs during import
|
54
|
+
setup_logging(quiet_mode=True)
|
55
|
+
|
56
|
+
# Setup environment
|
57
|
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent.parent))
|
58
|
+
|
59
|
+
# Disable v1 logging to see v3 tracing clearly
|
60
|
+
os.environ["LANGFUSE_ENABLED"] = "false"
|
61
|
+
os.environ["SYNTH_LOGGING"] = "false"
|
62
|
+
|
63
|
+
import numpy as np
|
64
|
+
import itertools
|
65
|
+
|
66
|
+
# Import Synth warmup utilities
|
67
|
+
from synth_ai.lm.warmup import warmup_synth_model
|
68
|
+
from synth_ai.lm.config import SynthConfig
|
69
|
+
|
70
|
+
# Import session tracer for v3 tracing
|
71
|
+
from synth_ai.tracing_v3 import SessionTracer
|
72
|
+
from synth_ai.tracing_v3.abstractions import (
|
73
|
+
SessionEventMessage, TimeRecord,
|
74
|
+
RuntimeEvent, EnvironmentEvent, LMCAISEvent
|
75
|
+
)
|
76
|
+
# create_experiment_context will be defined as a helper function below
|
77
|
+
from synth_ai.tracing_v3.turso.manager import AsyncSQLTraceManager
|
78
|
+
from synth_ai.tracing_v3.turso.daemon import SqldDaemon
|
79
|
+
|
80
|
+
# Import Crafter hooks for v3
|
81
|
+
from synth_ai.tracing_v3.hooks import HookManager
|
82
|
+
try:
|
83
|
+
from synth_ai.environments.examples.crafter_classic.trace_hooks_v3 import CRAFTER_HOOKS_V3
|
84
|
+
# Create a custom hook manager without default print statements
|
85
|
+
QUIET_HOOKS = HookManager()
|
86
|
+
# Don't add any hooks for now to keep output clean
|
87
|
+
except ImportError:
|
88
|
+
QUIET_HOOKS = HookManager()
|
89
|
+
|
90
|
+
# Import LM components (v3 version if available)
|
91
|
+
try:
|
92
|
+
from synth_ai.lm.core.main_v3 import LM
|
93
|
+
except ImportError:
|
94
|
+
from synth_ai.lm.core.main_v2 import LM
|
95
|
+
|
96
|
+
# Configuration constants
|
97
|
+
HTTP_TIMEOUT = 30.0 # Increased from 10.0 for better handling of concurrent load and LM response times
|
98
|
+
MAX_RETRIES = 3
|
99
|
+
RETRY_DELAY = 1.0
|
100
|
+
|
101
|
+
# Use the backend
|
102
|
+
|
103
|
+
|
104
|
+
async def create_experiment_context(db_manager: AsyncSQLTraceManager, experiment_name: str, description: str) -> Dict[str, Any]:
    """Register a new experiment and return its identifying fields.

    A fresh id of the form ``exp_<12 hex chars>`` is generated, the experiment
    is created through ``db_manager.create_experiment`` with an empty
    configuration, and the id, name and description are returned as a dict.
    """
    context = {
        'experiment_id': "exp_" + uuid.uuid4().hex[:12],
        'experiment_name': experiment_name,
        'description': description,
    }
    await db_manager.create_experiment(
        experiment_id=context['experiment_id'],
        name=experiment_name,
        description=description,
        configuration={},
    )
    return context
|
118
|
+
|
119
|
+
|
120
|
+
def cleanup_old_files():
    """Delete all but the five most recently modified result files.

    Looks for ``crafter_lm_synth_results_*.json`` in the current working
    directory. Files that cannot be removed are skipped silently.
    """
    keep = 5
    candidates = glob.glob("crafter_lm_synth_results_*.json")
    if len(candidates) <= keep:
        return

    # Newest first, so everything past the first `keep` entries is stale
    newest_first = sorted(candidates, key=lambda path: os.path.getmtime(path), reverse=True)
    for stale in newest_first[keep:]:
        try:
            os.remove(stale)
            print(f"🗑️ Cleaned up old result file: {stale}")
        except OSError:
            pass
|
133
|
+
|
134
|
+
|
135
|
+
def _load_env_from_monorepo() -> dict:
|
136
|
+
"""Load environment variables from monorepo/.env.local if present."""
|
137
|
+
env_file = Path(__file__).resolve().parent.parent.parent.parent.parent.parent / "monorepo/.env.local"
|
138
|
+
env_vars = {}
|
139
|
+
|
140
|
+
if env_file.exists():
|
141
|
+
with open(env_file, "r") as f:
|
142
|
+
for line in f:
|
143
|
+
line = line.strip()
|
144
|
+
if line and not line.startswith('#') and '=' in line:
|
145
|
+
key, value = line.split('=', 1)
|
146
|
+
# Remove quotes if present
|
147
|
+
value = value.strip().strip('"').strip("'")
|
148
|
+
env_vars[key] = value
|
149
|
+
|
150
|
+
return env_vars
|
151
|
+
|
152
|
+
|
153
|
+
def _load_testing_yaml_api_key() -> Optional[str]:
    """Resolve a Synth API key from the monorepo's local configuration files.

    Lookup order:
      1. ``SYNTH_API_KEY_PROD`` from ``monorepo/.env.local``
      2. ``SYNTH_API_KEY_TEST`` from ``monorepo/.env.local``
      3. ``SYNTH_API_KEY`` from ``monorepo/tests/prod/testing_info.yaml``

    Returns:
        The first key found, or ``None`` when no source provides one
        (including when the YAML file is empty or is not a mapping).
    """
    env_vars = _load_env_from_monorepo()

    # Production key takes precedence over the test key
    for env_name in ("SYNTH_API_KEY_PROD", "SYNTH_API_KEY_TEST"):
        if env_name in env_vars:
            return env_vars[env_name]

    # Fall back to the older YAML file
    yaml_path = Path(__file__).resolve().parent.parent.parent.parent.parent.parent / "monorepo/tests/prod/testing_info.yaml"
    if not yaml_path.exists():
        return None

    with open(yaml_path, "r") as f:
        data = yaml.safe_load(f)

    # safe_load yields None for an empty document and may yield a list or
    # scalar; only a mapping can carry the key.
    if not isinstance(data, dict):
        return None
    return data.get("SYNTH_API_KEY")
|
171
|
+
|
172
|
+
|
173
|
+
def setup_synth_environment():
    """Point the OpenAI-compatible environment variables at the Synth backend.

    Resolution order for the base URL:
    1. ``SYNTH_BASE_URL`` environment variable
    2. ``MODAL_BASE_URL`` environment variable
    3. ``PROD_API_URL`` environment variable
    4. ``SYNTH_BASE_URL_PROD`` from monorepo/.env.local
    5. Hard-coded default (https://agent-learning.onrender.com/api)

    The API key is taken from the ``SYNTH_API_KEY`` environment variable or,
    if that is unset, from ``_load_testing_yaml_api_key()``.

    Side effects:
        Sets ``OPENAI_API_BASE`` and ``OPENAI_BASE_URL`` to the resolved URL,
        and ``OPENAI_API_KEY`` to the resolved key when one was found.

    Returns:
        ``(base_url, api_key)``. ``base_url`` always ends in ``/v1`` with no
        trailing slash; ``api_key`` is ``None`` when no key could be resolved.
    """
    env_vars = _load_env_from_monorepo()

    synth_base_url = (
        os.getenv("SYNTH_BASE_URL")
        or os.getenv("MODAL_BASE_URL")
        or os.getenv("PROD_API_URL")
        or env_vars.get("SYNTH_BASE_URL_PROD")  # Production URL from .env.local
        or "https://agent-learning.onrender.com/api"
    )

    synth_api_key = os.getenv("SYNTH_API_KEY") or _load_testing_yaml_api_key()

    # NOTE: validation of the key prefix ("sk-", "sk_live_", "sk_test_") is
    # currently disabled; any non-empty key is accepted as-is.

    # Strip trailing slashes *before* checking for the suffix so that a URL
    # such as ".../v1/" is not turned into ".../v1/v1".
    synth_base_url = synth_base_url.rstrip("/")
    if not synth_base_url.endswith("/v1"):
        synth_base_url += "/v1"

    # Propagate to the OpenAI SDK env vars expected by the LM class
    os.environ["OPENAI_API_BASE"] = synth_base_url
    os.environ["OPENAI_BASE_URL"] = synth_base_url
    if synth_api_key:
        os.environ["OPENAI_API_KEY"] = synth_api_key
    else:
        # os.environ only accepts strings; assigning None would raise a
        # TypeError. Leave any existing OPENAI_API_KEY in place instead.
        logging.getLogger(__name__).warning(
            "No Synth API key found (SYNTH_API_KEY, monorepo/.env.local or "
            "testing_info.yaml); OPENAI_API_KEY was left unchanged."
        )

    return synth_base_url, synth_api_key
|
225
|
+
|
226
|
+
|
227
|
+
async def retry_http_request(client: AsyncClient, method: str, url: str, **kwargs) -> Any:
    """Issue an HTTP request, retrying on 5xx responses and exceptions.

    Up to ``MAX_RETRIES`` attempts are made, each with ``HTTP_TIMEOUT`` as the
    request timeout. Any response with a status code below 500 is returned
    immediately (4xx responses included). Before every retry the coroutine
    sleeps ``min(RETRY_DELAY * 2**(attempt-1), RETRY_DELAY * 2)`` seconds plus
    up to 10% random jitter. A ``httpx.ReadError`` additionally sleeps
    ``min(2**attempt, 5)`` seconds unless it occurred on the final attempt.

    Raises:
        The last recorded exception (or a generic ``Exception`` describing the
        last 5xx response) once all attempts are exhausted.
    """
    failure = None

    for attempt_index in range(MAX_RETRIES):
        try:
            if attempt_index > 0:
                base_wait = min(RETRY_DELAY * (2 ** (attempt_index - 1)), RETRY_DELAY * 2)
                await asyncio.sleep(base_wait + random.uniform(0, 0.1 * base_wait))

            reply = await client.request(method, url, timeout=HTTP_TIMEOUT, **kwargs)
            if reply.status_code < 500:
                return reply

            # Server-side error: remember it and fall through to the next attempt
            failure = Exception(f"HTTP {reply.status_code}: {reply.text}")

        except httpx.ReadError as read_err:
            failure = read_err
            is_final_attempt = attempt_index >= MAX_RETRIES - 1
            if not is_final_attempt:
                await asyncio.sleep(min(1.0 * (2 ** attempt_index), 5.0))
        except Exception as err:
            failure = err

    print(f" ❌ HTTP request failed after {MAX_RETRIES} attempts: {type(failure).__name__}: {str(failure)[:200]}")
    raise failure
|
256
|
+
|
257
|
+
|
258
|
+
def create_message(content: Any, message_type: str, origin_system_id: Any, turn: int) -> SessionEventMessage:
    """Wrap ``content`` in a SessionEventMessage tagged with its origin system.

    The message content is a JSON string holding ``origin_system_id`` (as a
    string) and the original ``content`` under ``payload``. ``message_type`` is
    kept when it is one of user/assistant/system/tool_use/tool_result;
    "observation" and any other value are stored as "system". The time record
    uses the current wall-clock time and ``turn`` as the message time.
    """
    envelope = json.dumps({
        "origin_system_id": str(origin_system_id),
        "payload": content,
    })

    # Types that v3 accepts unchanged; everything else collapses to "system"
    passthrough_types = {"user", "assistant", "system", "tool_use", "tool_result"}
    resolved_type = message_type if message_type in passthrough_types else "system"

    stamp = TimeRecord(event_time=time.time(), message_time=turn)
    return SessionEventMessage(
        content=envelope,
        message_type=resolved_type,
        time_record=stamp,
    )
|
281
|
+
|
282
|
+
|
283
|
+
def compress_observation_for_trace(obs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of ``obs`` without its bulky fields.

    The ``semantic_map`` and ``rgb`` entries are dropped (when present) so that
    stored traces stay small. The input dict itself is not modified.
    """
    slim = obs.copy()
    for bulky_key in ("semantic_map", "rgb"):
        slim.pop(bulky_key, None)
    return slim
|
296
|
+
|
297
|
+
|
298
|
+
|
299
|
+
@functools.lru_cache(maxsize=1)
def _crafter_id_to_item() -> tuple:
    """Build the semantic-id -> lowercase-name table from the crafter library.

    The table is read from a throwaway ``crafter.Env`` and cached, so the
    environment is only constructed once per process.
    """
    import crafter

    probe_env = crafter.Env()
    try:
        mat_ids = probe_env._world._mat_ids
        obj_ids = probe_env._sem_view._obj_ids
        table = ["void"] * (max(max(mat_ids.values()), max(obj_ids.values())) + 1)
        for key, idx in itertools.chain(mat_ids.items(), obj_ids.items()):
            # Keys may be classes (use their name), plain values, or None
            label = key.__name__ if hasattr(key, "__name__") else (str(key) if key is not None else "none")
            table[idx] = label.lower()
    finally:
        try:
            probe_env.close()
        except Exception:
            # Closing the probe environment is best-effort
            pass
    return tuple(table)


def format_semantic_map_view_v2(obs: Dict[str, Any], view_size: int = 7) -> str:
    """Render a text grid of item names centred on the player.

    Args:
        obs: Observation dict. ``semantic_map`` holds the integer id grid (a
            flat array is reshaped to a square); ``player_position`` is
            optional and defaults to the centre of the map.
        view_size: Width and height of the rendered window, in cells.

    Returns:
        ``view_size`` lines of space-separated item names (the player's cell is
        ``you``, out-of-bounds cells are ``void``), followed by a legend line
        listing the visible items other than grass/void. Returns
        ``"No semantic map available"`` when the observation has no map.
    """
    semantic_map = obs.get("semantic_map")
    if semantic_map is None:
        return "No semantic map available"

    sem_arr = np.asarray(semantic_map)
    if sem_arr.ndim == 1:
        # Flat input: assume a square map and reshape
        size = int(np.sqrt(sem_arr.size))
        sem_arr = sem_arr.reshape(size, size)

    player_pos = obs.get("player_position", [sem_arr.shape[0] // 2, sem_arr.shape[1] // 2])
    px, py = int(player_pos[0]), int(player_pos[1])

    # Real crafter id mapping, built once and reused across calls
    id_to_item = _crafter_id_to_item()

    half = view_size // 2
    lines = []
    visible_items = set()

    for dy in range(-half, half + 1):
        row = []
        for dx in range(-half, half + 1):
            x, y = px + dx, py + dy

            if dx == 0 and dy == 0:
                row.append('you')  # Player
            elif 0 <= x < sem_arr.shape[0] and 0 <= y < sem_arr.shape[1]:
                val = int(sem_arr[x, y])
                # Bound on both sides so a negative id cannot wrap around
                item_name = id_to_item[val] if 0 <= val < len(id_to_item) else f"unknown_{val}"
                row.append(item_name)
                if item_name not in ('grass', 'you', 'void'):
                    visible_items.add(item_name)
            else:
                row.append('void')  # Out of bounds

        lines.append(' '.join(row))

    legend = f"Visible items: {', '.join(sorted(visible_items))}" if visible_items else "No special items visible (mostly grass)"

    return "\n".join(lines) + "\n" + legend
|
360
|
+
|
361
|
+
|
362
|
+
def get_openai_tools():
    """Return the OpenAI-format tool schemas offered to the agent.

    Two function tools are defined: ``interact`` (a list of action strings plus
    the reasoning for them) and ``terminate`` (a reason for ending the episode).
    A fresh list is built on every call.
    """
    def _function_tool(name, description, properties, required):
        # Wrap a parameter spec in the OpenAI "function" tool envelope
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    interact_properties = {
        "actions": {
            "type": "array",
            "items": {"type": "string"},
            "description": "List of actions to perform in sequence (e.g., ['move_right', 'move_right', 'do']). Available actions: move_left, move_right, move_up, move_down, do, sleep, place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword, noop",
        },
        "reasoning": {
            "type": "string",
            "description": "Reasoning for these actions",
        },
    }
    terminate_properties = {
        "reason": {
            "type": "string",
            "description": "Reason for termination",
        },
    }

    interact_tool = _function_tool(
        "interact",
        "Perform actions in the Crafter environment.",
        interact_properties,
        ["actions", "reasoning"],
    )
    terminate_tool = _function_tool(
        "terminate",
        "End the episode when finished or no progress can be made.",
        terminate_properties,
        ["reason"],
    )
    return [interact_tool, terminate_tool]
|
407
|
+
|
408
|
+
|
409
|
+
# --- Configuration Class ---
|
410
|
+
class CrafterConfig:
    """Settings for a Crafter evaluation run against the Synth backend.

    Every option starts from a built-in default and can be overridden by a
    TOML file with ``eval``, ``service``, ``output``, ``tracing_v3`` and
    ``synth`` tables.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Populate defaults, then overlay ``config_path`` if that file exists."""
        # Evaluation
        self.model_name: Optional[str] = None
        self.num_instances = 1
        self.max_turns = 2
        self.difficulty = "easy"
        self.seed = 42

        # Environment service
        self.service_base_url = "http://localhost:8901"
        self.service_timeout = 30.0

        # Output / verbosity
        self.save_traces = True
        self.save_detailed_results = True
        self.verbose = False
        self.quiet = False
        self.analyze_traces = False

        # V3 tracing: one shared database by default, overridable via env
        trace_root = os.getenv("SYNTH_TRACES_ROOT", "./traces/v3")
        self.enable_v3_tracing = True
        self.v3_trace_dir = trace_root
        self.turso_db_path = os.getenv("SQLD_DB_PATH", os.path.join(trace_root, "synth_ai.db"))
        self.start_sqld_daemon = True  # Whether to start the sqld daemon
        self.auto_cleanup = True  # Remove old files automatically

        # Synth backend
        self.warmup_model = True
        self.warmup_max_attempts = 30
        self.warmup_timeout = 60.0  # seconds
        self.use_synth_backend = True

        # A missing or empty path simply leaves the defaults in place
        if config_path and os.path.exists(config_path):
            self.load_from_toml(config_path)

    def load_from_toml(self, config_path: str):
        """Overlay values from a TOML file; absent keys keep their current value."""
        parsed = toml.load(config_path)

        # table name -> ((toml key, attribute name), ...)
        layout = {
            "eval": (
                ("model_name", "model_name"),
                ("episodes", "num_instances"),
                ("max_steps", "max_turns"),
                ("difficulty", "difficulty"),
                ("seed", "seed"),
            ),
            "service": (
                ("base_url", "service_base_url"),
                ("timeout", "service_timeout"),
            ),
            "output": (
                ("save_traces", "save_traces"),
                ("save_detailed_results", "save_detailed_results"),
            ),
            "tracing_v3": (
                ("enabled", "enable_v3_tracing"),
                ("trace_dir", "v3_trace_dir"),
                ("db_path", "turso_db_path"),
                ("start_daemon", "start_sqld_daemon"),
                ("auto_cleanup", "auto_cleanup"),
            ),
            "synth": (
                ("warmup_model", "warmup_model"),
                ("warmup_max_attempts", "warmup_max_attempts"),
                ("warmup_timeout", "warmup_timeout"),
                ("use_synth_backend", "use_synth_backend"),
            ),
        }

        for table_name, fields in layout.items():
            table = parsed.get(table_name, {})
            for toml_key, attr_name in fields:
                setattr(self, attr_name, table.get(toml_key, getattr(self, attr_name)))
|
482
|
+
|
483
|
+
|
484
|
+
# --- Base ReAct Agent using LM with Synth ---
|
485
|
+
class BaseReActAgentWithLMSynth:
|
486
|
+
"""Base ReAct agent using LM class configured for Synth backend."""
|
487
|
+
|
488
|
+
def __init__(self, model_name: str, max_turns: int = 20, verbose: bool = False,
             tracer: Optional[SessionTracer] = None, episode_id: int = 0, quiet: bool = False,
             model_params: Optional[Dict[str, Any]] = None):
    """Configure the agent, its sampling parameters and its LM client.

    Args:
        model_name: Model used both for generation and for formatting.
        max_turns: Turn budget; also the initial ``steps_remaining``.
        verbose: Stored on the instance as ``self.verbose``.
        tracer: Optional session tracer handed to the LM.
        episode_id: Stored on the instance as ``self.episode_id``.
        quiet: Stored on the instance as ``self.quiet``.
        model_params: Overrides merged on top of the default sampling
            parameters. May also carry ``enable_thinking`` and ``extra_body``.
    """
    # Identity and bookkeeping
    self.system_name = "base-react-agent-lm-synth"
    self.system_id = f"{self.system_name}_{uuid.uuid4()}"
    self.model_name = model_name
    self.episode_id = episode_id
    self.max_turns = max_turns
    self.verbose = verbose
    self.quiet = quiet
    self.tracer = tracer
    self.history = []
    self.tools = get_openai_tools()

    # Sampling defaults, with caller-supplied values taking precedence
    sampling = {
        "temperature": 0.7,
        "max_tokens": 512,
        "top_p": 1.0,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
        "tool_choice": "auto",
    }
    sampling.update(model_params or {})
    self.model_params = sampling

    # The Synth environment variables must be in place before the LM is built
    setup_synth_environment()

    self.lm = LM(
        model_name=model_name,
        formatting_model_name=model_name,
        temperature=sampling["temperature"],
        synth_logging=False,  # Disable v1 tracing
        provider="synth",  # Use synth provider
        session_tracer=tracer,
        system_id=self.system_id,
        enable_v3_tracing=True,
        max_tokens=sampling["max_tokens"],
        top_p=sampling["top_p"],
        frequency_penalty=sampling["frequency_penalty"],
        presence_penalty=sampling["presence_penalty"],
        # Qwen3 think mode (propagated by vendor to chat_template_kwargs)
        enable_thinking=sampling.get("enable_thinking"),
        # Arbitrary extra_body forwarded to the vendor for features like
        # stop_after_tool_calls
        extra_body=sampling.get("extra_body"),
    )

    # Per-episode agent state
    self.agent_state = {
        "message_history": [],
        "steps_taken": 0,
        "steps_remaining": max_turns,
        "total_tokens_used": 0,
        "tool_calls_made": 0,
        "current_turn": 0,
        "last_failure": None,  # Last failure, used when building the next prompt
    }
|
550
|
+
|
551
|
+
async def decide(self, obs: str, system_message: str, turn: int) -> Dict[str, Any]:
|
552
|
+
"""Get agent decision based on observation using LM class with Synth backend."""
|
553
|
+
# Update agent state
|
554
|
+
self.agent_state["current_turn"] = turn
|
555
|
+
self.agent_state["steps_taken"] = turn
|
556
|
+
self.agent_state["steps_remaining"] = self.max_turns - turn
|
557
|
+
|
558
|
+
# Create conversation context with unique episode ID to prevent caching
|
559
|
+
context = f"Episode {self.episode_id} - Turn {turn + 1}/{self.max_turns}\n\n{obs}"
|
560
|
+
|
561
|
+
# Build messages in OpenAI format for tools
|
562
|
+
# Augment the system message if the previous turn failed to produce a tool call
|
563
|
+
local_system_message = system_message
|
564
|
+
last_failure = self.agent_state.get("last_failure")
|
565
|
+
if last_failure:
|
566
|
+
local_system_message = (
|
567
|
+
f"{system_message}\n\nIMPORTANT: In the previous turn, no valid tool call was returned. "
|
568
|
+
f"Error: {last_failure}. You MUST respond with a single function tool call in the OpenAI tools format."
|
569
|
+
)
|
570
|
+
messages = [
|
571
|
+
{"role": "system", "content": local_system_message},
|
572
|
+
{"role": "user", "content": context}
|
573
|
+
]
|
574
|
+
|
575
|
+
# Add to message history
|
576
|
+
self.agent_state["message_history"].extend(messages)
|
577
|
+
|
578
|
+
# Truncate history if too long
|
579
|
+
max_history_length = 20
|
580
|
+
if len(self.agent_state["message_history"]) > max_history_length:
|
581
|
+
self.agent_state["message_history"] = (
|
582
|
+
[self.agent_state["message_history"][0]] +
|
583
|
+
self.agent_state["message_history"][-(max_history_length-1):]
|
584
|
+
)
|
585
|
+
|
586
|
+
try:
|
587
|
+
llm_start = time.time()
|
588
|
+
|
589
|
+
# Optionally print full prompt on final turn when verbose
|
590
|
+
if self.verbose and turn == self.max_turns - 1:
|
591
|
+
print("\n🔍 FINAL TURN PROMPT:")
|
592
|
+
print("="*80)
|
593
|
+
print(f"System: {local_system_message[:200]}...")
|
594
|
+
print(f"\nUser message:\n{context}")
|
595
|
+
print("="*80)
|
596
|
+
|
597
|
+
# Debug: Print request info only when verbose
|
598
|
+
if self.verbose:
|
599
|
+
print(f"\n🔍 DEBUG: LM call details (turn {turn})")
|
600
|
+
print(f" Model: {self.model_name}")
|
601
|
+
print(f" Provider: synth")
|
602
|
+
print(f" Messages: {len(messages)} messages")
|
603
|
+
print(f" Tools: {len(self.tools) if self.tools else 0} tools")
|
604
|
+
if self.tools:
|
605
|
+
print(f" Tool 0 name: {self.tools[0].get('function', {}).get('name', 'unknown')}")
|
606
|
+
print(f" Tools structure: {json.dumps(self.tools[0], indent=4)[:300]}...")
|
607
|
+
|
608
|
+
# Call LM with turn number for v3 tracing
|
609
|
+
# The LM class should handle Synth routing internally
|
610
|
+
if self.verbose:
|
611
|
+
print(f"🔍 DEBUG: LM sampling params => max_tokens={self.model_params.get('max_tokens')} temp={self.model_params.get('temperature')} top_p={self.model_params.get('top_p')} tool_choice={self.model_params.get('tool_choice')}")
|
612
|
+
response = await self.lm.respond_async(
|
613
|
+
messages=messages,
|
614
|
+
turn_number=turn,
|
615
|
+
# Pass tools in the format expected by LM class
|
616
|
+
tools=self.tools,
|
617
|
+
max_tokens=self.model_params["max_tokens"],
|
618
|
+
tool_choice=self.model_params.get("tool_choice", "auto"),
|
619
|
+
# Pass extra_body per call to ensure backend receives stop_after_tool_calls
|
620
|
+
extra_body=self.model_params.get("extra_body")
|
621
|
+
)
|
622
|
+
|
623
|
+
llm_end = time.time()
|
624
|
+
|
625
|
+
# Minimal output: show only tool_call presence, number of actions, and tokens
|
626
|
+
completion_tokens = None
|
627
|
+
prompt_tokens = None
|
628
|
+
toks_per_sec = None
|
629
|
+
if hasattr(response, 'usage') and isinstance(getattr(response, 'usage'), dict):
|
630
|
+
completion_tokens = response.usage.get('completion_tokens')
|
631
|
+
prompt_tokens = response.usage.get('prompt_tokens')
|
632
|
+
# Compute tokens/sec if we have duration and completion tokens
|
633
|
+
try:
|
634
|
+
if completion_tokens is not None:
|
635
|
+
duration_s = max(1e-6, (llm_end - llm_start))
|
636
|
+
toks_per_sec = round(float(completion_tokens) / duration_s, 2)
|
637
|
+
except Exception:
|
638
|
+
toks_per_sec = None
|
639
|
+
|
640
|
+
# Parse the response to extract tool calls
|
641
|
+
raw_response = response.raw_response
|
642
|
+
decision: Dict[str, Any]
|
643
|
+
|
644
|
+
if hasattr(response, 'tool_calls') and response.tool_calls:
|
645
|
+
tool_call = response.tool_calls[0]
|
646
|
+
parsed_decision = None
|
647
|
+
fn = tool_call.get("function") if isinstance(tool_call, dict) else None
|
648
|
+
if isinstance(fn, dict) and ("name" in fn):
|
649
|
+
name = fn.get("name", "interact")
|
650
|
+
args_raw = fn.get("arguments", "{}")
|
651
|
+
try:
|
652
|
+
import json as _json
|
653
|
+
args = _json.loads(args_raw) if isinstance(args_raw, str) else (args_raw or {})
|
654
|
+
if isinstance(args, dict):
|
655
|
+
parsed_decision = {"name": name, "parameters": args}
|
656
|
+
except Exception as _e:
|
657
|
+
parsed_decision = {"name": name, "parameters": {"arguments": args_raw}}
|
658
|
+
if not parsed_decision and isinstance(tool_call, dict):
|
659
|
+
if "name" in tool_call or "parameters" in tool_call:
|
660
|
+
parsed_decision = {
|
661
|
+
"name": tool_call.get("name", "interact"),
|
662
|
+
"parameters": tool_call.get("parameters", {}),
|
663
|
+
}
|
664
|
+
if parsed_decision:
|
665
|
+
decision = parsed_decision
|
666
|
+
# Clear failure flag on success
|
667
|
+
if self.agent_state.get("last_failure"):
|
668
|
+
self.agent_state["last_failure"] = None
|
669
|
+
params = decision.get('parameters', {}) if isinstance(decision, dict) else {}
|
670
|
+
actions = params.get('actions', []) if isinstance(params, dict) else []
|
671
|
+
num_actions = len(actions) if isinstance(actions, list) else 0
|
672
|
+
# Store metrics for tqdm postfix update in run_episode
|
673
|
+
self.agent_state["last_metrics"] = {
|
674
|
+
"tc": 1,
|
675
|
+
"act": num_actions,
|
676
|
+
"tok": completion_tokens,
|
677
|
+
"in": prompt_tokens,
|
678
|
+
"tps": f"{toks_per_sec}" if toks_per_sec is not None else "-",
|
679
|
+
}
|
680
|
+
else:
|
681
|
+
# Unrecognized tool_calls structure: do nothing, record failure
|
682
|
+
failure_msg = "Unrecognized tool_calls structure"
|
683
|
+
self.agent_state["last_failure"] = failure_msg
|
684
|
+
decision = {"name": "interact", "parameters": {"actions": [], "reasoning": failure_msg}}
|
685
|
+
if self.verbose:
|
686
|
+
print(f"🔍 DEBUG: {failure_msg}")
|
687
|
+
else:
|
688
|
+
# No tool calls: do nothing, record failure for next prompt
|
689
|
+
failure_msg = "No valid tool_calls in assistant message"
|
690
|
+
self.agent_state["last_failure"] = failure_msg
|
691
|
+
decision = {"name": "interact", "parameters": {"actions": [], "reasoning": failure_msg}}
|
692
|
+
# Store metrics for tqdm postfix update in run_episode
|
693
|
+
self.agent_state["last_metrics"] = {
|
694
|
+
"tc": 0,
|
695
|
+
"act": 0,
|
696
|
+
"tok": completion_tokens,
|
697
|
+
"in": prompt_tokens,
|
698
|
+
"tps": f"{toks_per_sec}" if toks_per_sec is not None else "-",
|
699
|
+
}
|
700
|
+
|
701
|
+
# Update agent state
|
702
|
+
self.agent_state["tool_calls_made"] += 1
|
703
|
+
|
704
|
+
# Add assistant response to history
|
705
|
+
assistant_message = {
|
706
|
+
"role": "assistant",
|
707
|
+
"content": raw_response
|
708
|
+
}
|
709
|
+
self.agent_state["message_history"].append(assistant_message)
|
710
|
+
|
711
|
+
if self.verbose:
|
712
|
+
print(f"🤖 LM Response (turn {turn}): {json.dumps(decision, indent=2)}")
|
713
|
+
print(f"📊 Response time: {llm_end - llm_start:.2f}s")
|
714
|
+
except Exception as e:
|
715
|
+
print(f"❌ Error in LM decide: {e}")
|
716
|
+
import traceback
|
717
|
+
traceback.print_exc()
|
718
|
+
# Record failure and do nothing this turn
|
719
|
+
failure_msg = f"Exception during decide: {str(e)}"
|
720
|
+
self.agent_state["last_failure"] = failure_msg
|
721
|
+
decision = {
|
722
|
+
"name": "interact",
|
723
|
+
"parameters": {
|
724
|
+
"actions": [],
|
725
|
+
"reasoning": failure_msg
|
726
|
+
}
|
727
|
+
}
|
728
|
+
|
729
|
+
return decision
|
730
|
+
|
731
|
+
def _parse_tool_response(self, raw_response: str) -> Dict[str, Any]:
|
732
|
+
"""Parse raw LM response to extract tool calls."""
|
733
|
+
# Try to parse JSON if present
|
734
|
+
try:
|
735
|
+
# Look for JSON in the response
|
736
|
+
import re
|
737
|
+
json_match = re.search(r'\{.*\}', raw_response, re.DOTALL)
|
738
|
+
if json_match:
|
739
|
+
data = json.loads(json_match.group())
|
740
|
+
if "name" in data:
|
741
|
+
return data
|
742
|
+
elif "function" in data:
|
743
|
+
return {
|
744
|
+
"name": data["function"].get("name", "interact"),
|
745
|
+
"parameters": data["function"].get("arguments", {})
|
746
|
+
}
|
747
|
+
except:
|
748
|
+
pass
|
749
|
+
|
750
|
+
# Fallback to text parsing
|
751
|
+
if "terminate" in raw_response.lower():
|
752
|
+
return {
|
753
|
+
"name": "terminate",
|
754
|
+
"parameters": {
|
755
|
+
"reason": "Agent decided to terminate"
|
756
|
+
}
|
757
|
+
}
|
758
|
+
|
759
|
+
# Try to extract actions from the response
|
760
|
+
actions = []
|
761
|
+
action_keywords = [
|
762
|
+
"move_up", "move_down", "move_left", "move_right", "do", "sleep",
|
763
|
+
"place_stone", "place_table", "place_furnace", "place_plant",
|
764
|
+
"make_wood_pickaxe", "make_stone_pickaxe", "make_iron_pickaxe",
|
765
|
+
"make_wood_sword", "make_stone_sword", "make_iron_sword"
|
766
|
+
]
|
767
|
+
|
768
|
+
for keyword in action_keywords:
|
769
|
+
if keyword in raw_response.lower():
|
770
|
+
actions.append(keyword)
|
771
|
+
|
772
|
+
if not actions:
|
773
|
+
actions = ["do"] # Default action
|
774
|
+
|
775
|
+
return {
|
776
|
+
"name": "interact",
|
777
|
+
"parameters": {
|
778
|
+
"actions": actions, # Return as array of actions
|
779
|
+
"reasoning": "Parsed from response"
|
780
|
+
}
|
781
|
+
}
|
782
|
+
|
783
|
+
def get_system_message(self) -> str:
    """Build the default system prompt; subclasses may override this.

    The prompt tells the model it is playing Crafter and insists on 2-5
    actions per ``interact`` tool call, with good and bad examples.
    """
    prompt_lines = (
        "You are an AI agent playing Crafter. Use the available tools to interact with the environment.",
        "",
        "CRITICAL RULE: You MUST provide MULTIPLE actions (2-5) in EVERY interact() tool call!",
        "",
        "The 'interact' function accepts a LIST of 1-5 actions. ALWAYS provide 2-5 actions for efficiency.",
        "",
        "GOOD Examples (what you SHOULD do):",
        '✓ interact(actions=["move_right", "move_right", "do"], reasoning="Move to tree and collect wood")',
        '✓ interact(actions=["move_up", "move_up", "move_right", "do"], reasoning="Navigate to stone and mine it")',
        '✓ interact(actions=["place_table", "make_wood_pickaxe", "move_left"], reasoning="Craft and continue exploring")',
        "",
        "BAD Examples (what you should AVOID):",
        '✗ interact(actions=["move_right"], reasoning="Move right") - TOO FEW ACTIONS!',
        '✗ interact(actions=["do"], reasoning="Collect") - TOO FEW ACTIONS!',
        "",
        "REMEMBER: Single actions waste time. Always plan 2-5 actions ahead and execute them together!",
    )
    return "\n".join(prompt_lines)
|
801
|
+
|
802
|
+
def format_observation(self, obs: Dict[str, Any]) -> str:
    """Render an observation as text for the agent; subclasses may override.

    The default rendering is simply the ``str()`` form of the observation.
    """
    rendered = str(obs)
    return rendered
|
805
|
+
|
806
|
+
|
807
|
+
# --- Crafter-specific ReAct Agent ---
|
808
|
+
class CrafterReActAgentWithLMSynth(BaseReActAgentWithLMSynth):
    """ReAct agent specialised for Crafter, with Synth-oriented prompting."""

    def get_system_message(self) -> str:
        """Return the Crafter system prompt.

        A non-empty ``CRAFTER_SYSTEM_PROMPT`` environment variable replaces
        the built-in prompt entirely.
        """
        default_prompt = """You are CrafterAgent playing Crafter survival environment. Your goal is to unlock as many achievements as possible while staying alive.

You will see a semantic map view showing your surroundings. Use this to navigate toward resources.

Key mechanics:
• 'do' action: collect wood from trees, stone from deposits, food from cows/plants
• 'do' does nothing on grass/water - move to find resources first
• Craft progression: wood → table → wood_pickaxe → stone → stone_pickaxe → iron tools
• Sleep when energy low to restore and unlock wake_up achievement
• Use semantic map view to navigate toward resources you can see

Available actions: move_left, move_right, move_up, move_down, do, sleep, place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword, noop

KEY ACHIEVEMENTS TO UNLOCK:
Basic Resource Collection (PRIORITY #1):
- collect_wood: Move NEXT TO a tree, then use action="do" to collect wood
- collect_stone: Move NEXT TO stone, then use action="do" (requires wood_pickaxe in inventory)
- collect_coal: Move NEXT TO coal, then use action="do" (requires stone_pickaxe)
- collect_iron: Move NEXT TO iron, then use action="do" (requires stone_pickaxe)
- collect_diamond: Move NEXT TO diamond, then use action="do" (requires iron_pickaxe)

Tool Crafting (enables resource collection):
- make_wood_pickaxe: Use action="make_wood_pickaxe" when you have wood (unlocks ability to mine stone)
- make_stone_pickaxe: Use action="make_stone_pickaxe" when you have wood and stone (unlocks coal/iron mining)
- make_iron_pickaxe: Use action="make_iron_pickaxe" when you have wood, coal, and iron (unlocks diamond mining)

Weapon Crafting (for defense):
- make_wood_sword: Use action="make_wood_sword" when you have wood
- make_stone_sword: Use action="make_stone_sword" when you have wood and stone
- make_iron_sword: Use action="make_iron_sword" when you have wood, coal, and iron

Survival Actions:
- eat_plant: Use action="eat_plant" when food < 9 and you see a plant nearby
- eat_cow: Move NEXT TO cow, use action="do" to kill it, then action="eat_cow"
- collect_drink: Move NEXT TO water, then use action="drink" when drink < 9
- sleep: Use action="sleep" when energy < 5 (restores energy to 9)

Building/Placing:
- place_table: Use action="place_table" when you have wood (enables advanced crafting)
- place_furnace: Use action="place_furnace" when you have stone (for smelting)
- place_plant: Use action="place_plant" when you have sapling (grows into tree)
- place_stone: Use action="place_stone" when you have stone (creates barrier)

Combat:
- defeat_zombie: Move NEXT TO zombie, then use action="do" repeatedly to attack
- defeat_skeleton: Move NEXT TO skeleton, then use action="do" repeatedly to attack

CRITICAL: The action="do" is your INTERACTION button! Use it when adjacent to:
- Trees → get wood
- Stone/Coal/Iron/Diamond → mine resources (need appropriate pickaxe)
- Enemies → attack them
- Cows → kill for food

Simple Strategy:
1. Look for resources (trees, stones) in the semantic map
2. Move toward the nearest resource
3. When adjacent to a resource, use action="do" to collect it
4. If you have wood, try action="make_wood_pickaxe"
5. Repeat: find resources, move to them, use "do"

Critical Gameplay Tips:
- You must be ADJACENT (one tile away) to objects to interact with them
- Use "do" when next to: trees (for wood), stone (for stone), coal, iron, diamond
- Use "do" to attack zombies/skeletons when adjacent
- First priority: Find a tree, move next to it, then use "do" to collect wood
- Wood is essential for crafting your first pickaxe
- With wood_pickaxe you can mine stone, with stone_pickaxe you can mine iron, etc.

CRITICAL INSTRUCTION: You MUST ALWAYS provide MULTIPLE actions (2-5) in EVERY interact() tool call!

The 'interact' function accepts a LIST of 1-5 actions. NEVER use single actions - always plan 2-5 actions ahead!

MANDATORY action sequences (ALWAYS use multiple):
✓ interact(actions=["move_right", "move_right", "do"], reasoning="Move to tree and collect wood")
✓ interact(actions=["move_up", "move_up", "move_right", "do"], reasoning="Navigate and collect")
✓ interact(actions=["place_table", "make_wood_pickaxe", "move_left", "move_left"], reasoning="Craft and explore")
✓ interact(actions=["do", "move_right", "do", "move_right", "do"], reasoning="Collect multiple resources")

FORBIDDEN (NEVER do this):
✗ interact(actions=["move_right"], ...) - WRONG! Too few actions!
✗ interact(actions=["do"], ...) - WRONG! Too few actions!

RULE: If you use less than 2 actions, you are playing inefficiently. Always think 2-5 steps ahead!

Key Strategy:
1. Plan a sequence of moves to reach resources
2. Execute multiple moves in one tool call (e.g., ["move_right", "move_right", "move_up"])
3. When adjacent to a resource, use "do" to collect it
4. Chain crafting actions together (e.g., ["place_table", "make_wood_pickaxe"])

Remember:
- Use "do" when ADJACENT to trees (for wood), stones, or other resources
- Collect wood FIRST before trying to craft anything
- Be efficient - use multiple actions per tool call!
- Focus on unlocking achievements by collecting resources and crafting items."""
        # An unset or empty override falls through to the built-in prompt.
        return os.getenv("CRAFTER_SYSTEM_PROMPT") or default_prompt

    def format_observation(self, obs: Dict[str, Any]) -> str:
        """Render a Crafter observation as the user-turn prompt text.

        The text contains the 7x7 semantic map, the four vitals, the non-zero
        inventory entries and the achievement summary. Unless the
        ``CRAFTER_SUPPRESS_OBS_REMINDER`` environment variable is set to a
        non-empty value, a reminder to send 2-5 actions is appended.
        """
        map_text = format_semantic_map_view_v2(obs, view_size=7)

        bag = obs.get('inventory', {})
        # Achievements may be reported under either key.
        achievement_flags = obs.get('achievements_status', obs.get('achievements_info', {}))
        hp = obs.get('health', 10)
        hunger = obs.get('food', 10)
        thirst = obs.get('drink', 10)
        stamina = obs.get('energy', 10)

        earned = [name for name, flag in achievement_flags.items() if flag]
        earned_count = len(earned)
        known_count = len(achievement_flags)

        # Only items the agent actually holds are listed.
        held = [f"{name}: {amount}" for name, amount in bag.items() if amount > 0]
        bag_text = ", ".join(held) if held else "empty"
        earned_text = ", ".join(earned) if earned else "none"

        # Placeholder for recent achievements; currently always empty.
        recent_text = ""

        segments = [
            f"=== SEMANTIC MAP VIEW (7x7) ===\n",
            f"{map_text}\n\n",
            f"=== STATUS ===\n",
            f"Health: {hp}/10 | Food: {hunger}/10 | Drink: {thirst}/10 | Energy: {stamina}/10\n",
            f"Inventory: {bag_text}\n",
            f"Achievements: {earned_count}/{known_count} unlocked\n",
            f"Unlocked: {earned_text}\n",
            f"{recent_text}\n\n",
            f"What do you see in the map? What actions should you take? ",
        ]
        status_block = "".join(segments)

        if os.getenv("CRAFTER_SUPPRESS_OBS_REMINDER"):
            return status_block

        reminder = (
            "\n\nREMINDER: You MUST provide 2-5 actions in your interact() tool call. Plan multiple steps ahead!\n"
            'Example: interact(actions=["move_right", "move_right", "do"], reasoning="Move to tree and collect wood")'
        )
        return status_block + reminder
|
963
|
+
|
964
|
+
|
965
|
+
async def run_episode(
    episode_id: int,
    config: CrafterConfig,
    session_tracer: Optional[SessionTracer] = None,
    progress_bar: Optional[tqdm] = None,
    quiet: bool = False,
    model_params: Optional[Dict[str, Any]] = None
):
    """Run a single Crafter episode against the environment service.

    The episode initializes an environment (seed ``episode_id + 1``), then for
    up to ``config.max_turns`` turns asks the agent for a decision and steps
    the environment once per requested action. When the service reports a
    reward of 0/None, the reward is replaced by the number of achievements
    newly unlocked by that step. The episode stops early when the agent
    terminates, the environment reports ``done``, or the model returns three
    consecutive empty ``interact`` decisions (reported as a failed episode).

    Args:
        episode_id: Index of the episode; also used for the seed and the
            progress-bar position.
        config: Run configuration (model name, max turns, difficulty,
            service base URL, verbosity).
        session_tracer: Optional tracer. When given, observations and
            environment/runtime events are recorded per timestep; when
            ``None`` the episode runs untraced.
        progress_bar: Accepted for interface compatibility; not used here.
        quiet: Accepted for interface compatibility; not used here (the agent
            is always created quiet and non-verbose).
        model_params: Sampling parameters forwarded to the agent.

    Returns:
        On success, a dict with ``episode_id``, ``total_reward``, ``steps``,
        ``termination_reason``, ``duration``, ``step_results`` and
        ``achievements_unlocked``. On any exception, a dict with
        ``episode_id``, ``error`` and ``duration``.
    """
    # Imported locally so the block does not depend on a module-level import.
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def _untraced_timestep():
        """No-op stand-in for ``session_tracer.timestep`` when tracing is off."""
        yield

    episode_start_time = time.time()

    # Action name -> integer id sent to the environment service. Built once
    # per episode instead of once per turn.
    CRAFTER_ACTION_MAP = {
        "noop": 0,
        "move_left": 1,
        "move_right": 2,
        "move_up": 3,
        "move_down": 4,
        "do": 5,
        "sleep": 6,
        "place_stone": 7,
        "place_table": 8,
        "place_furnace": 9,
        "place_plant": 10,
        "make_wood_pickaxe": 11,
        "make_stone_pickaxe": 12,
        "make_iron_pickaxe": 13,
        "make_wood_sword": 14,
        "make_stone_sword": 15,
        "make_iron_sword": 16,
    }

    # Create agent - always disable verbose for cleaner output
    agent = CrafterReActAgentWithLMSynth(
        model_name=config.model_name,
        max_turns=config.max_turns,
        verbose=False,  # Always disable verbose logging in agent
        tracer=session_tracer,
        episode_id=episode_id,
        quiet=True,  # Always use quiet mode for agent
        model_params=model_params
    )

    async with AsyncClient(base_url=config.service_base_url) as client:
        # Tracked outside the try so the failure path knows what to clean up.
        instance_id = None
        episode_progress = None
        terminate_requested = False
        try:
            # Simple sequential seeds: 1, 2, 3, ... (start from 1 instead of 0)
            episode_seed = episode_id + 1

            init_response = await retry_http_request(
                client, "POST", "/env/CrafterClassic/initialize",
                json={
                    "config": {
                        "difficulty": config.difficulty,
                        "seed": episode_seed
                    }
                }
            )

            init_data = init_response.json()
            instance_id = init_data["env_id"]
            obs = init_data["observation"]

            # Start initial timestep and send initial observation as message
            if session_tracer:
                async with session_tracer.timestep("init", turn_number=0):
                    obs_msg = create_message(
                        compress_observation_for_trace(obs),
                        "observation",
                        f"crafter_env_{instance_id}",
                        0
                    )
                    await session_tracer.record_message(
                        content=obs_msg.content,
                        message_type=obs_msg.message_type
                    )

            # Run episode
            episode_reward = 0
            termination_reason = None
            step_results = []
            consecutive_no_tool_calls = 0

            # Create progress bar for this episode
            episode_progress = tqdm(
                total=config.max_turns,
                desc=f"Episode {episode_id}",
                position=episode_id,
                leave=True,
                ncols=100
            )

            for turn in range(config.max_turns):
                episode_progress.update(1)

                # asyncio has no ``nullcontext``; use a local no-op async
                # context so untraced runs do not fail on the first turn.
                timestep_name = f"turn_{turn+1}"
                if session_tracer:
                    timestep_ctx = session_tracer.timestep(timestep_name, turn_number=turn+1)
                else:
                    timestep_ctx = _untraced_timestep()

                async with timestep_ctx:
                    # Get agent decision
                    obs_formatted = agent.format_observation(obs)
                    system_msg = agent.get_system_message()

                    decision = await agent.decide(obs_formatted, system_msg, turn)

                    # Progress-bar postfix is cosmetic: never let it abort the episode.
                    try:
                        metrics = agent.agent_state.get("last_metrics")
                        if isinstance(metrics, dict):
                            episode_progress.set_postfix(metrics, refresh=False)
                    except Exception:
                        pass

                    # Handle termination
                    if decision["name"] == "terminate":
                        termination_reason = decision["parameters"]["reason"]
                        break

                    # Detect consecutive no-tool-call responses and abort after 3
                    decision_params = decision.get("parameters") if isinstance(decision, dict) else None
                    decision_actions = decision_params.get("actions", []) if isinstance(decision_params, dict) else []
                    if decision.get("name") == "interact" and isinstance(decision_actions, list) and len(decision_actions) == 0:
                        consecutive_no_tool_calls += 1
                        print(f"🔍 DEBUG: consecutive_no_tool_calls={consecutive_no_tool_calls}")
                    else:
                        consecutive_no_tool_calls = 0
                    if consecutive_no_tool_calls >= 3:
                        raise RuntimeError("Aborting episode due to 3 consecutive no-tool-calls from the model")

                    # Actions to execute this turn (may be empty)
                    actions = decision_actions

                    # Ensure control variables are defined even if no actions are taken this turn
                    done = False
                    reward = 0.0
                    info = {}

                    for action in actions:
                        # Unknown action names fall back to noop (0)
                        action_int = CRAFTER_ACTION_MAP.get(action, 0)

                        # Snapshot state before the step
                        state_before = {"observation": obs}
                        prev_obs = obs.copy()

                        # Step environment
                        step_response = await retry_http_request(
                            client, "POST", "/env/CrafterClassic/step",
                            json={
                                "env_id": instance_id,
                                "action": {"tool_calls": [{"tool": "interact", "args": {"action": action_int}}]}
                            }
                        )
                        step_data = step_response.json()

                        # Malformed response: report it and skip the rest of this turn's actions
                        if "observation" not in step_data:
                            print(f"\n❌ Error: Missing observation in step response. Keys: {list(step_data.keys())}")
                            if "error" in step_data:
                                print(f" Error message: {step_data['error']}")
                            break

                        obs = step_data["observation"]
                        # NOTE: the defaults apply only when the key is absent;
                        # an explicit null from the service stays None.
                        reward = step_data.get("reward", 0)
                        done = step_data.get("done", False)
                        info = step_data.get("info", {})

                        # Calculate achievement reward if not provided by service
                        if reward == 0 or reward is None:
                            if 'achievements_status' in obs and 'achievements_status' in prev_obs:
                                prev_achievements = prev_obs['achievements_status']
                                curr_achievements = obs['achievements_status']
                                new_unlocks = sum(
                                    1 for k in curr_achievements
                                    if curr_achievements.get(k) and not prev_achievements.get(k)
                                )
                                if new_unlocks > 0:
                                    reward = float(new_unlocks)  # +1 for each new achievement

                        if reward is not None:
                            episode_reward += reward

                        # Record step result
                        step_results.append({
                            "turn": turn,
                            "action": action,
                            "reward": reward,
                            "done": done,
                            "info": info
                        })

                        # Record environment event for hooks to catch
                        if session_tracer:
                            # Environment event with state transition
                            env_event = EnvironmentEvent(
                                time_record=TimeRecord(
                                    event_time=time.time(),
                                    message_time=turn
                                ),
                                system_instance_id=f"crafter_env_{instance_id}",
                                system_state_before={"public_state": prev_obs},
                                system_state_after={"public_state": obs},
                                reward=reward,  # Includes calculated achievement rewards
                                terminated=done,
                                metadata={
                                    "action": action,
                                    "action_int": action_int,
                                    "info": info
                                }
                            )
                            await session_tracer.record_event(env_event)

                            # Also record runtime event for invalid action detection
                            runtime_event = RuntimeEvent(
                                time_record=TimeRecord(
                                    event_time=time.time(),
                                    message_time=turn
                                ),
                                system_instance_id=f"crafter_runtime_{instance_id}",
                                actions=[action_int],
                                metadata={
                                    "action_name": action,
                                    "action_int": action_int,
                                    "reward": reward,
                                    "state_before": state_before,
                                    "state_after": {"observation": obs}
                                }
                            )
                            await session_tracer.record_event(runtime_event)

                        if done:
                            break

                    # After all actions (or none), send final observation message
                    if session_tracer:
                        obs_msg = create_message(
                            compress_observation_for_trace(obs),
                            "observation",
                            f"crafter_env_{instance_id}",
                            turn + 1
                        )
                        await session_tracer.record_message(
                            content=obs_msg.content,
                            message_type=obs_msg.message_type
                        )

                    if done:
                        break

            # Close progress bar
            episode_progress.close()

            # Terminate instance
            terminate_requested = True
            await retry_http_request(
                client, "POST", "/env/CrafterClassic/terminate",
                json={"env_id": instance_id}
            )

        except Exception as e:
            if episode_progress is not None:
                episode_progress.close()
            print(f"\n❌ Episode {episode_id} failed: {e}")
            if config.verbose:
                import traceback
                traceback.print_exc()

            # Best-effort cleanup so a failed episode does not leave its
            # environment instance alive on the service. Skipped when the
            # failure came from the terminate request itself.
            if instance_id is not None and not terminate_requested:
                try:
                    await retry_http_request(
                        client, "POST", "/env/CrafterClassic/terminate",
                        json={"env_id": instance_id}
                    )
                except Exception as cleanup_error:
                    if config.verbose:
                        print(f"⚠️ Could not terminate environment {instance_id}: {cleanup_error}")

            return {
                "episode_id": episode_id,
                "error": str(e),
                "duration": time.time() - episode_start_time
            }

    # Extract final achievements
    final_achievements = []
    if obs and 'achievements_status' in obs:
        final_achievements = [k for k, v in obs['achievements_status'].items() if v]

    # Return results
    return {
        "episode_id": episode_id,
        "total_reward": episode_reward,
        "steps": len(step_results),
        "termination_reason": termination_reason,
        "duration": time.time() - episode_start_time,
        "step_results": step_results,
        "achievements_unlocked": final_achievements
    }
|
1253
|
+
|
1254
|
+
|
1255
|
+
# --- Main ---
|
1256
|
+
async def main():
|
1257
|
+
"""Main entry point with v3 tracing."""
|
1258
|
+
parser = argparse.ArgumentParser(description="Run Crafter evaluation with LM Synth backend")
|
1259
|
+
parser.add_argument("--config", type=str, help="Path to TOML config file")
|
1260
|
+
parser.add_argument("--model", type=str, help="Model name (overrides config)")
|
1261
|
+
parser.add_argument("--episodes", type=int, help="Number of episodes (overrides config)")
|
1262
|
+
parser.add_argument("--max-steps", type=int, help="Max steps per episode (overrides config)")
|
1263
|
+
parser.add_argument("--difficulty", type=str, choices=["easy", "normal", "hard"], help="Difficulty override")
|
1264
|
+
parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
|
1265
|
+
parser.add_argument("--quiet", action="store_true", help="Suppress most output except results")
|
1266
|
+
parser.add_argument("--no-traces", action="store_true", help="Disable trace saving")
|
1267
|
+
parser.add_argument("--analyze", action="store_true", help="Analyze traces after running")
|
1268
|
+
parser.add_argument("--skip-warmup", action="store_true", help="Skip model warmup")
|
1269
|
+
parser.add_argument("--no-daemon", action="store_true", help="Don't start sqld daemon (assumes it's already running)")
|
1270
|
+
|
1271
|
+
# Qwen3 thinking mode flags (mutually exclusive)
|
1272
|
+
think_group = parser.add_mutually_exclusive_group()
|
1273
|
+
think_group.add_argument("--think", dest="enable_thinking", action="store_true", help="Enable Qwen3 thinking mode (chat_template_kwargs.enable_thinking=True)")
|
1274
|
+
think_group.add_argument("--no-think", dest="enable_thinking", action="store_false", help="Disable Qwen3 thinking mode (chat_template_kwargs.enable_thinking=False)")
|
1275
|
+
parser.set_defaults(enable_thinking=None)
|
1276
|
+
|
1277
|
+
# Model parameter arguments
|
1278
|
+
parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for model responses (default: 0.7)")
|
1279
|
+
parser.add_argument("--max-tokens", type=int, default=512, help="Maximum tokens to generate (default: 512)")
|
1280
|
+
parser.add_argument("--top-p", type=float, default=1.0, help="Top-p sampling parameter (default: 1.0)")
|
1281
|
+
parser.add_argument("--frequency-penalty", type=float, default=0.0, help="Frequency penalty (default: 0.0)")
|
1282
|
+
parser.add_argument("--presence-penalty", type=float, default=0.0, help="Presence penalty (default: 0.0)")
|
1283
|
+
parser.add_argument("--tool-choice", type=str, choices=["auto", "required", "none"], default="auto", help="Tool choice mode (default: auto)")
|
1284
|
+
|
1285
|
+
args = parser.parse_args()
|
1286
|
+
|
1287
|
+
# Load configuration
|
1288
|
+
config = CrafterConfig(args.config)
|
1289
|
+
|
1290
|
+
# Setup Synth environment variables
|
1291
|
+
setup_synth_environment()
|
1292
|
+
|
1293
|
+
# Clean up old files to keep directory clean
|
1294
|
+
if config.auto_cleanup:
|
1295
|
+
cleanup_old_files()
|
1296
|
+
|
1297
|
+
# Apply command-line overrides
|
1298
|
+
if args.model:
|
1299
|
+
config.model_name = args.model
|
1300
|
+
if args.episodes:
|
1301
|
+
config.num_instances = args.episodes
|
1302
|
+
if args.max_steps:
|
1303
|
+
config.max_turns = args.max_steps
|
1304
|
+
if args.difficulty:
|
1305
|
+
config.difficulty = args.difficulty
|
1306
|
+
if args.verbose:
|
1307
|
+
config.verbose = True
|
1308
|
+
if args.quiet:
|
1309
|
+
config.quiet = True
|
1310
|
+
if not args.verbose: # Don't show this if verbose is also on
|
1311
|
+
print("🔇 Quiet mode enabled - suppressing verbose logs")
|
1312
|
+
else:
|
1313
|
+
config.quiet = False
|
1314
|
+
if args.no_daemon:
|
1315
|
+
config.start_sqld_daemon = False
|
1316
|
+
|
1317
|
+
# Environment overrides for model parameters (fail-fast on bad values)
|
1318
|
+
env_temp = os.getenv("CRAFTER_TEMPERATURE")
|
1319
|
+
if env_temp is not None:
|
1320
|
+
args.temperature = float(env_temp)
|
1321
|
+
env_max_tok = os.getenv("CRAFTER_MAX_TOKENS")
|
1322
|
+
if env_max_tok is not None:
|
1323
|
+
args.max_tokens = int(env_max_tok)
|
1324
|
+
env_tool_choice = os.getenv("CRAFTER_TOOL_CHOICE")
|
1325
|
+
if env_tool_choice is not None:
|
1326
|
+
if env_tool_choice not in {"auto", "required", "none"}:
|
1327
|
+
raise ValueError(f"Invalid CRAFTER_TOOL_CHOICE: {env_tool_choice}")
|
1328
|
+
args.tool_choice = env_tool_choice
|
1329
|
+
env_top_p = os.getenv("CRAFTER_TOP_P")
|
1330
|
+
if env_top_p is not None:
|
1331
|
+
args.top_p = float(env_top_p)
|
1332
|
+
env_freq_pen = os.getenv("CRAFTER_FREQUENCY_PENALTY")
|
1333
|
+
if env_freq_pen is not None:
|
1334
|
+
args.frequency_penalty = float(env_freq_pen)
|
1335
|
+
env_pres_pen = os.getenv("CRAFTER_PRESENCE_PENALTY")
|
1336
|
+
if env_pres_pen is not None:
|
1337
|
+
args.presence_penalty = float(env_pres_pen)
|
1338
|
+
|
1339
|
+
# Resolve stop-after-tool-calls from environment (wrapper sets this)
|
1340
|
+
try:
|
1341
|
+
_satc = int(os.getenv("CRAFTER_STOP_AFTER_TOOL_CALLS", "1"))
|
1342
|
+
except Exception:
|
1343
|
+
_satc = 1
|
1344
|
+
_extra_body = {"stop_after_tool_calls": _satc} if _satc and _satc > 0 else {}
|
1345
|
+
|
1346
|
+
# Create model parameters dictionary from command line arguments
|
1347
|
+
model_params = {
|
1348
|
+
"temperature": args.temperature,
|
1349
|
+
"max_tokens": args.max_tokens,
|
1350
|
+
"top_p": args.top_p,
|
1351
|
+
"frequency_penalty": args.frequency_penalty,
|
1352
|
+
"presence_penalty": args.presence_penalty,
|
1353
|
+
"tool_choice": args.tool_choice,
|
1354
|
+
# Request early stop after N tool call blocks to avoid spillover
|
1355
|
+
"extra_body": _extra_body
|
1356
|
+
}
|
1357
|
+
# Optionally carry thinking mode through to LM config
|
1358
|
+
if args.enable_thinking is not None:
|
1359
|
+
model_params["enable_thinking"] = args.enable_thinking
|
1360
|
+
|
1361
|
+
# Configure logging based on quiet mode
|
1362
|
+
setup_logging(quiet_mode=config.quiet)
|
1363
|
+
|
1364
|
+
# Display configuration (only if not in quiet mode)
|
1365
|
+
if not config.quiet:
|
1366
|
+
print(f"🎮 Crafter ReAct Agent Evaluation (LM with Synth Backend - v3)")
|
1367
|
+
print(f"Model: {config.model_name}")
|
1368
|
+
print(f"Model Parameters:")
|
1369
|
+
print(f" Temperature: {model_params['temperature']}")
|
1370
|
+
print(f" Max Tokens: {model_params['max_tokens']}")
|
1371
|
+
print(f" Top-p: {model_params['top_p']}")
|
1372
|
+
print(f" Frequency Penalty: {model_params['frequency_penalty']}")
|
1373
|
+
print(f" Presence Penalty: {model_params['presence_penalty']}")
|
1374
|
+
print(f"Service: {config.service_base_url}")
|
1375
|
+
print(f"Instances: {config.num_instances}")
|
1376
|
+
print(f"Max Turns: {config.max_turns}")
|
1377
|
+
print(f"Difficulty: {config.difficulty}")
|
1378
|
+
print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
1379
|
+
print("=" * 50)
|
1380
|
+
|
1381
|
+
if args.no_traces:
|
1382
|
+
config.save_traces = False
|
1383
|
+
config.enable_v3_tracing = False
|
1384
|
+
if args.analyze:
|
1385
|
+
config.analyze_traces = True
|
1386
|
+
if args.skip_warmup:
|
1387
|
+
config.warmup_model = False
|
1388
|
+
|
1389
|
+
# Ensure model is specified
|
1390
|
+
if not config.model_name:
|
1391
|
+
parser.error("Model name must be specified via --model or config file")
|
1392
|
+
|
1393
|
+
|
1394
|
+
|
1395
|
+
# Test service health
|
1396
|
+
async with AsyncClient(base_url=config.service_base_url) as client:
|
1397
|
+
try:
|
1398
|
+
health_resp = await retry_http_request(client, "GET", "/health")
|
1399
|
+
health_data = health_resp.json()
|
1400
|
+
print(f"✅ Crafter service is healthy: {health_data}")
|
1401
|
+
except Exception as e:
|
1402
|
+
print(f"❌ Failed to connect to Crafter service: {e}")
|
1403
|
+
return
|
1404
|
+
|
1405
|
+
# Warm up the model if requested
|
1406
|
+
if config.warmup_model and not args.skip_warmup:
|
1407
|
+
print(f"\n🔥 Warming up {config.model_name} on Synth backend...")
|
1408
|
+
try:
|
1409
|
+
synth_base_url = os.getenv('SYNTH_BASE_URL')# or os.getenv('MODAL_BASE_URL')
|
1410
|
+
synth_api_key = os.getenv('SYNTH_API_KEY')# or os.getenv('MODAL_API_KEY')
|
1411
|
+
if synth_base_url and synth_api_key:
|
1412
|
+
synth_config = SynthConfig(
|
1413
|
+
base_url=synth_base_url,
|
1414
|
+
api_key=synth_api_key,
|
1415
|
+
timeout=config.warmup_timeout # Use configurable timeout
|
1416
|
+
)
|
1417
|
+
warmed = await warmup_synth_model(config.model_name, synth_config)
|
1418
|
+
if warmed:
|
1419
|
+
print("✅ Model warmed up successfully!")
|
1420
|
+
else:
|
1421
|
+
print("⚠️ Warmup did not complete; continuing anyway...")
|
1422
|
+
else:
|
1423
|
+
print("⚠️ Missing SYNTH_BASE_URL or SYNTH_API_KEY, skipping warmup")
|
1424
|
+
except Exception as e:
|
1425
|
+
print(f"⚠️ Warmup failed: {e}")
|
1426
|
+
print("Continuing anyway...")
|
1427
|
+
|
1428
|
+
# Set up v3 tracing if enabled
|
1429
|
+
trace_manager = None
|
1430
|
+
experiment_ctx = None
|
1431
|
+
sqld_daemon = None
|
1432
|
+
|
1433
|
+
if config.enable_v3_tracing:
|
1434
|
+
# Create trace directory first
|
1435
|
+
os.makedirs(config.v3_trace_dir, exist_ok=True)
|
1436
|
+
|
1437
|
+
# Start sqld daemon if requested
|
1438
|
+
if config.start_sqld_daemon:
|
1439
|
+
print(f"\n🚀 Starting sqld daemon for v3 tracing...")
|
1440
|
+
sqld_daemon = SqldDaemon(db_path=config.turso_db_path)
|
1441
|
+
sqld_daemon.__enter__() # Start the daemon
|
1442
|
+
await asyncio.sleep(2) # Give it time to start
|
1443
|
+
print("✅ sqld daemon started")
|
1444
|
+
|
1445
|
+
# Initialize trace manager with proper URL format
|
1446
|
+
# If SQLD_DB_PATH is a directory managed by sqld, use its data file
|
1447
|
+
_db_path = config.turso_db_path
|
1448
|
+
if os.path.isdir(_db_path):
|
1449
|
+
_candidate = os.path.join(_db_path, "dbs", "default", "data")
|
1450
|
+
if os.path.exists(_candidate):
|
1451
|
+
_db_path = _candidate
|
1452
|
+
db_url = f"sqlite+aiosqlite:///{os.path.abspath(_db_path)}"
|
1453
|
+
trace_manager = AsyncSQLTraceManager(db_url=db_url)
|
1454
|
+
await trace_manager.initialize()
|
1455
|
+
|
1456
|
+
# Create experiment context
|
1457
|
+
experiment_ctx = await create_experiment_context(
|
1458
|
+
db_manager=trace_manager,
|
1459
|
+
experiment_name=f"crafter_lm_synth_{config.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
1460
|
+
description=f"Crafter LM Synth experiment with {config.model_name} on {config.difficulty} difficulty, using LM class with v3 tracing"
|
1461
|
+
)
|
1462
|
+
|
1463
|
+
print(f"\n📊 V3 Tracing enabled. Traces will be saved to: {config.turso_db_path}")
|
1464
|
+
print(f" Experiment: {experiment_ctx['experiment_name']}")
|
1465
|
+
|
1466
|
+
# Run episodes with bounded concurrency using asyncio.Semaphore
|
1467
|
+
# Control concurrency with env var CRAFTER_CONCURRENCY (default 5)
|
1468
|
+
try:
|
1469
|
+
_conc_str = os.getenv("CRAFTER_CONCURRENCY")
|
1470
|
+
max_concurrency = int(_conc_str) if _conc_str else 5
|
1471
|
+
except Exception:
|
1472
|
+
max_concurrency = 5
|
1473
|
+
concurrency_limiter = asyncio.Semaphore(max_concurrency)
|
1474
|
+
|
1475
|
+
print(f"\n🚀 Running {config.num_instances} episodes (concurrency={max_concurrency})...")
|
1476
|
+
|
1477
|
+
episode_seeds = [] # Track seeds used for each episode
|
1478
|
+
|
1479
|
+
# Prepare episode tasks
|
1480
|
+
episode_tasks = []
|
1481
|
+
session_ids = []
|
1482
|
+
|
1483
|
+
for i in range(config.num_instances):
|
1484
|
+
# Calculate episode seed for logging (simple sequential: 1, 2, 3, etc)
|
1485
|
+
episode_seed = i + 1
|
1486
|
+
episode_seeds.append(episode_seed)
|
1487
|
+
|
1488
|
+
# Create session tracer for this episode if v3 tracing is enabled
|
1489
|
+
session_tracer = None
|
1490
|
+
if config.enable_v3_tracing and trace_manager:
|
1491
|
+
session_tracer = SessionTracer(hooks=QUIET_HOOKS) # Use quiet hooks
|
1492
|
+
session_tracer.db = trace_manager # Use existing manager
|
1493
|
+
session_tracer._initialized = True
|
1494
|
+
|
1495
|
+
# Generate session ID
|
1496
|
+
session_id = f"crafter_episode_{i}_{uuid.uuid4().hex[:8]}"
|
1497
|
+
session_ids.append(session_id)
|
1498
|
+
|
1499
|
+
# Create episode task with proper session context
|
1500
|
+
async def run_episode_with_session(ep_id, cfg, tracer, pb, quiet, sess_id, model_params):
    """Run a single episode, inside a tracer session when a tracer is supplied.

    Args:
        ep_id: Episode identifier; forwarded to ``run_episode`` and recorded
            in the session metadata.
        cfg: Run configuration forwarded to ``run_episode``.
        tracer: Session tracer, or a falsy value to run without a session.
        pb: Forwarded unchanged to ``run_episode``.
        quiet: Forwarded unchanged to ``run_episode``.
        sess_id: Session id passed to ``tracer.session`` when tracing.
        model_params: Model parameters forwarded to ``run_episode``.

    Returns:
        Whatever ``run_episode`` returns.
    """
    # No tracer: run the episode directly, without opening a session.
    if not tracer:
        return await run_episode(ep_id, cfg, tracer, pb, quiet, model_params)

    # The experiment id comes from the enclosing scope's experiment context
    # and is None when no experiment context was created.
    session_metadata = {
        "episode_id": ep_id,
        "experiment_id": experiment_ctx['experiment_id'] if experiment_ctx else None,
    }
    async with tracer.session(session_id=sess_id, metadata=session_metadata):
        return await run_episode(ep_id, cfg, tracer, pb, quiet, model_params)
|
1512
|
+
|
1513
|
+
# Freeze per-iteration values to avoid late-binding bugs in closures
|
1514
|
+
this_tracer = session_tracer
|
1515
|
+
this_session_id = session_ids[i] if session_ids else None
|
1516
|
+
|
1517
|
+
async def _limited_episode(ep_idx=i, tracer=this_tracer, sess_id=this_session_id):
    """Run one episode while holding a slot of the shared concurrency limiter.

    The per-iteration values are bound as parameter defaults, so each
    coroutine keeps the values that were current when it was defined rather
    than the loop's final values.

    Returns:
        Whatever ``run_episode_with_session`` returns for this episode.
    """
    # The limiter is held only for the duration of the episode itself.
    async with concurrency_limiter:
        episode_outcome = await run_episode_with_session(
            ep_id=ep_idx,
            cfg=config,
            tracer=tracer,
            pb=None,
            quiet=args.quiet,
            sess_id=sess_id,
            model_params=model_params,
        )
    return episode_outcome
|
1522
|
+
|
1523
|
+
episode_task = _limited_episode()
|
1524
|
+
episode_tasks.append(episode_task)
|
1525
|
+
|
1526
|
+
print(f"\n📤 Starting episodes...")
|
1527
|
+
start_time = time.time()
|
1528
|
+
|
1529
|
+
# Run all episodes in parallel and fail fast on first error
|
1530
|
+
try:
|
1531
|
+
results = await asyncio.gather(*episode_tasks, return_exceptions=False)
|
1532
|
+
except Exception as e:
|
1533
|
+
print(f"\n❌ Run aborted due to error: {e}")
|
1534
|
+
# Ensure resources are cleaned up before exiting
|
1535
|
+
if trace_manager:
|
1536
|
+
await trace_manager.close()
|
1537
|
+
if sqld_daemon:
|
1538
|
+
sqld_daemon.__exit__(None, None, None)
|
1539
|
+
print("\n✅ Stopped sqld daemon")
|
1540
|
+
raise
|
1541
|
+
|
1542
|
+
end_time = time.time()
|
1543
|
+
parallel_time = end_time - start_time
|
1544
|
+
|
1545
|
+
print(f"\n✅ Completed {len(episode_tasks)} episodes in {parallel_time:.2f} seconds")
|
1546
|
+
|
1547
|
+
# Process results and handle any exceptions
|
1548
|
+
successful_results = []
|
1549
|
+
failed_results = []
|
1550
|
+
|
1551
|
+
for i, result in enumerate(results):
|
1552
|
+
if isinstance(result, Exception):
|
1553
|
+
print(f"❌ Episode {i} failed: {result}")
|
1554
|
+
failed_results.append({"episode_id": i, "error": str(result)})
|
1555
|
+
else:
|
1556
|
+
successful_results.append(result)
|
1557
|
+
|
1558
|
+
# Link session to experiment if tracing enabled
|
1559
|
+
if config.enable_v3_tracing and trace_manager and experiment_ctx and i < len(session_ids):
|
1560
|
+
await trace_manager.link_session_to_experiment(
|
1561
|
+
session_ids[i],
|
1562
|
+
experiment_ctx['experiment_id']
|
1563
|
+
)
|
1564
|
+
|
1565
|
+
# Combine successful results and failure records for the analysis below
|
1566
|
+
results = successful_results + failed_results
|
1567
|
+
|
1568
|
+
# Analyze results
|
1569
|
+
print("\n" + "=" * 50)
|
1570
|
+
print("📊 EVALUATION RESULTS")
|
1571
|
+
print("=" * 50)
|
1572
|
+
|
1573
|
+
successful_episodes = [r for r in results if 'error' not in r]
|
1574
|
+
failed_episodes = [r for r in results if 'error' in r]
|
1575
|
+
|
1576
|
+
if successful_episodes:
|
1577
|
+
total_reward = sum(r['total_reward'] for r in successful_episodes)
|
1578
|
+
total_steps = sum(r['steps'] for r in successful_episodes)
|
1579
|
+
avg_reward = total_reward / len(successful_episodes)
|
1580
|
+
avg_steps = total_steps / len(successful_episodes)
|
1581
|
+
|
1582
|
+
print(f"Episodes completed: {len(successful_episodes)}/{config.num_instances}")
|
1583
|
+
print(f"Failed episodes: {len(failed_episodes)}")
|
1584
|
+
print(f"Total reward: {total_reward:.2f}")
|
1585
|
+
print(f"Average reward per episode: {avg_reward:.2f}")
|
1586
|
+
print(f"Total steps: {total_steps}")
|
1587
|
+
print(f"Average steps per episode: {avg_steps:.2f}")
|
1588
|
+
|
1589
|
+
# Show seeds used
|
1590
|
+
if episode_seeds:
|
1591
|
+
print(f"\nSeeds used:")
|
1592
|
+
for i, seed in enumerate(episode_seeds[:len(successful_episodes)]):
|
1593
|
+
print(f" Episode {i}: seed {seed}")
|
1594
|
+
|
1595
|
+
# Extract unique achievements
|
1596
|
+
all_achievements = set()
|
1597
|
+
achievement_counts = defaultdict(int)
|
1598
|
+
|
1599
|
+
for result in successful_episodes:
|
1600
|
+
# Use the achievements_unlocked field we added
|
1601
|
+
if 'achievements_unlocked' in result:
|
1602
|
+
for achievement in result['achievements_unlocked']:
|
1603
|
+
all_achievements.add(achievement)
|
1604
|
+
achievement_counts[achievement] += 1
|
1605
|
+
|
1606
|
+
# Extract and count all actions from successful episodes
|
1607
|
+
action_counts = defaultdict(int)
|
1608
|
+
total_actions = 0
|
1609
|
+
|
1610
|
+
for result in successful_episodes:
|
1611
|
+
if 'step_results' in result:
|
1612
|
+
for step in result['step_results']:
|
1613
|
+
if 'action' in step:
|
1614
|
+
action_counts[step['action']] += 1
|
1615
|
+
total_actions += 1
|
1616
|
+
|
1617
|
+
print(f"Unique achievements unlocked: {len(all_achievements)}")
|
1618
|
+
if all_achievements:
|
1619
|
+
print("\nAchievements unlocked:")
|
1620
|
+
for achievement, count in sorted(achievement_counts.items()):
|
1621
|
+
print(f" - {achievement}: {count} episodes ({count/len(successful_episodes)*100:.1f}%)")
|
1622
|
+
|
1623
|
+
# Display action counts
|
1624
|
+
if action_counts:
|
1625
|
+
print(f"\nAction counts (total: {total_actions}):")
|
1626
|
+
for action, count in sorted(action_counts.items(), key=lambda x: x[1], reverse=True):
|
1627
|
+
percentage = count / total_actions * 100 if total_actions > 0 else 0
|
1628
|
+
print(f" - {action}: {count} ({percentage:.1f}%)")
|
1629
|
+
else:
|
1630
|
+
print("No successful episodes completed.")
|
1631
|
+
|
1632
|
+
# Save detailed results
|
1633
|
+
if config.save_detailed_results and config.enable_v3_tracing and trace_manager:
|
1634
|
+
# For v3, results are automatically saved in the database
|
1635
|
+
print(f"\n💾 Results available in Turso database: {config.turso_db_path}")
|
1636
|
+
print(f" Experiment ID: {experiment_ctx['experiment_id']}")
|
1637
|
+
print(f" Use the filter_traces_sft_turso.py script to extract fine-tuning data")
|
1638
|
+
elif config.save_detailed_results:
|
1639
|
+
# Fallback to JSON if no tracing
|
1640
|
+
results_file = f"crafter_lm_synth_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
1641
|
+
with open(results_file, 'w') as f:
|
1642
|
+
json.dump({
|
1643
|
+
'config': {
|
1644
|
+
'model': config.model_name,
|
1645
|
+
'episodes': config.num_instances,
|
1646
|
+
'max_steps': config.max_turns,
|
1647
|
+
'difficulty': config.difficulty,
|
1648
|
+
'backend': 'synth',
|
1649
|
+
'tracing': 'v3'
|
1650
|
+
},
|
1651
|
+
'results': results,
|
1652
|
+
'summary': {
|
1653
|
+
'successful_episodes': len(successful_episodes),
|
1654
|
+
'failed_episodes': len(failed_episodes),
|
1655
|
+
'total_reward': total_reward if successful_episodes else 0,
|
1656
|
+
'avg_reward': avg_reward if successful_episodes else 0,
|
1657
|
+
'unique_achievements': list(all_achievements) if successful_episodes else []
|
1658
|
+
}
|
1659
|
+
}, f, indent=2)
|
1660
|
+
print(f"\n💾 Detailed results saved to: {results_file}")
|
1661
|
+
|
1662
|
+
# Print a markdown row compatible with Environments/crafter.md tables
|
1663
|
+
if successful_episodes:
|
1664
|
+
# Columns: | model | trajectories | avg achievements | shaped reward avg | K-score avg | steps sum | avg steps |
|
1665
|
+
model_label = config.model_name.replace("/", "/")
|
1666
|
+
trajectories = len(successful_episodes)
|
1667
|
+
avg_ach = avg_reward # our reward == achievements unlocked per episode
|
1668
|
+
|
1669
|
+
# Compute weighted scores (shaped and K-Score) from final achievements across episodes
|
1670
|
+
# K coefficients taken from crafter.md (representative weights)
|
1671
|
+
k_weights = {
|
1672
|
+
"collect_drink": 0.1,
|
1673
|
+
"collect_sapling": 0.1,
|
1674
|
+
"wake_up": 0.1,
|
1675
|
+
"collect_wood": 1.0,
|
1676
|
+
"collect_stone": 1.0,
|
1677
|
+
"eat_cow": 1.0,
|
1678
|
+
"defeat_zombie": 1.0,
|
1679
|
+
"defeat_skeleton": 1.0,
|
1680
|
+
"make_wood_pickaxe": 3.0,
|
1681
|
+
"place_table": 3.0,
|
1682
|
+
"collect_coal": 3.0,
|
1683
|
+
"make_stone_pickaxe": 10.0,
|
1684
|
+
"place_furnace": 10.0,
|
1685
|
+
"collect_iron": 10.0,
|
1686
|
+
"make_stone_sword": 10.0,
|
1687
|
+
"make_wood_sword": 3.0,
|
1688
|
+
"place_plant": 0.1,
|
1689
|
+
}
|
1690
|
+
|
1691
|
+
# Aggregate final achievements across successful episodes
|
1692
|
+
from collections import Counter
|
1693
|
+
ach_counter: Counter[str] = Counter()
|
1694
|
+
for ep in successful_episodes:
|
1695
|
+
for name in ep.get("achievements_unlocked", []):
|
1696
|
+
ach_counter[name] += 1
|
1697
|
+
|
1698
|
+
shaped_total = 0.0
|
1699
|
+
for name, count in ach_counter.items():
|
1700
|
+
k = k_weights.get(name, 1.0)
|
1701
|
+
shaped_total += k * count
|
1702
|
+
|
1703
|
+
# Shaped reward per episode average
|
1704
|
+
shaped_reward_avg = shaped_total / trajectories if trajectories > 0 else 0.0
|
1705
|
+
k_score_avg = shaped_reward_avg / 20.0 # normalize roughly to match table scale
|
1706
|
+
|
1707
|
+
unique = len(all_achievements)
|
1708
|
+
steps_sum = total_steps
|
1709
|
+
avg_steps_md = avg_steps
|
1710
|
+
print("\nMarkdown row:")
|
1711
|
+
print(f"| {model_label:<15} | {trajectories:7d} | {avg_ach:8.2f} | {shaped_reward_avg:13.3f} | {k_score_avg:12.3f} | {steps_sum:12.3f} | {avg_steps_md:8.3f} |")
|
1712
|
+
|
1713
|
+
# Cleanup
|
1714
|
+
if trace_manager:
|
1715
|
+
await trace_manager.close()
|
1716
|
+
|
1717
|
+
if sqld_daemon:
|
1718
|
+
sqld_daemon.__exit__(None, None, None)
|
1719
|
+
print("\n✅ Stopped sqld daemon")
|
1720
|
+
|
1721
|
+
|
1722
|
+
if __name__ == "__main__":
    # Script entry point: run the async main() coroutine to completion.
    # Skipped when this module is imported rather than executed directly.
    asyncio.run(main())
|
1724
|
+
|
1725
|
+
|
1726
|
+
# === SEMANTIC MAP VIEW (15x15) ===
|
1727
|
+
# stone coal iron coal coal coal coal
|
1728
|
+
# stone stone iron coal coal coal coal
|
1729
|
+
# stone stone zombie coal coal iron iron
|
1730
|
+
# stone stone stone you stone iron iron
|
1731
|
+
# stone stone stone stone stone stone stone
|
1732
|
+
# stone stone stone stone stone stone stone
|
1733
|
+
# stone stone stone stone stone stone stone
|
1734
|
+
# Visible items: coal, iron, stone, zombie
|
1735
|
+
|
1736
|
+
# === STATUS ===
|
1737
|
+
# Health: 10/10 | Food: 10/10 | Drink: 10/10 | Energy: 10/10
|
1738
|
+
# Inventory: health: 9, food: 7, drink: 7, energy: 9, wood: 1, wood_pickaxe: 1
|
1739
|
+
# Achievements: 4/22 unlocked
|
1740
|
+
# Unlocked: collect_wood, make_wood_pickaxe, place_table, wake_up
|