synth-ai 0.2.4.dev8__py3-none-any.whl → 0.2.4.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic; see the package's registry page for more details.

Files changed (111):
  1. synth_ai/cli/__init__.py +6 -0
  2. synth_ai/cli/demo.py +68 -9
  3. synth_ai/cli/rl_demo.py +137 -0
  4. synth_ai/cli/root.py +65 -0
  5. synth_ai/demos/core/__init__.py +1 -0
  6. synth_ai/demos/core/cli.py +621 -0
  7. synth_ai/demos/demo_task_apps/__init__.py +1 -0
  8. synth_ai/demos/demo_task_apps/core.py +374 -0
  9. synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
  10. synth_ai/demos/demo_task_apps/math/app.py +37 -0
  11. synth_ai/demos/demo_task_apps/math/config.toml +44 -0
  12. synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
  13. synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
  14. synth_ai/environments/examples/bandit/__init__.py +33 -0
  15. synth_ai/environments/examples/bandit/engine.py +294 -0
  16. synth_ai/environments/examples/bandit/environment.py +194 -0
  17. synth_ai/environments/examples/bandit/taskset.py +200 -0
  18. synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
  25. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
  26. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
  27. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
  28. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
  29. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
  30. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
  31. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
  32. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
  33. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
  34. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
  35. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
  36. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
  37. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
  38. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
  39. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
  40. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
  41. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
  42. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
  43. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
  44. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
  45. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
  46. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
  47. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
  48. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
  49. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
  50. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
  51. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
  52. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
  53. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
  54. synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
  55. synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
  56. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
  57. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
  58. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
  59. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
  60. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
  61. synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
  62. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
  63. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
  64. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
  65. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
  66. synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
  67. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
  68. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
  69. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
  70. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
  71. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
  72. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
  73. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
  74. synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
  75. synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
  76. synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
  77. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
  78. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
  79. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
  80. synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
  81. synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
  82. synth_ai/environments/examples/crafter_classic/environment.py +41 -2
  83. synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
  84. synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
  85. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
  86. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
  87. synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
  88. synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
  89. synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
  90. synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
  91. synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
  92. synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
  93. synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
  94. synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
  95. synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
  96. synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
  97. synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
  98. synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
  99. synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
  100. synth_ai/environments/examples/red/units/__init__.py +1 -0
  101. synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
  102. synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
  103. synth_ai/environments/service/app.py +8 -0
  104. synth_ai/install_sqld.sh +40 -0
  105. synth_ai-0.2.4.dev9.dist-info/METADATA +91 -0
  106. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/RECORD +110 -11
  107. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/entry_points.txt +1 -0
  108. synth_ai-0.2.4.dev8.dist-info/METADATA +0 -635
  109. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/WHEEL +0 -0
  110. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/licenses/LICENSE +0 -0
  111. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,858 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Comprehensive script to run Crafter rollouts for multiple models and compare their performance.
4
+ Updated to use tracing_v3 with async architecture.
5
+
6
+ Runs experiments for:
7
+ - gpt-4o-mini
8
+ - gpt-4.1-mini
9
+ - gpt-4.1-nano
10
+ - gemini-1.5-flash
11
+ - gemini-2.5-flash-lite
12
+ - qwen3/32b
13
+
14
+ Analyzes and compares:
15
+ - Invalid action rates
16
+ - Achievement frequencies by step
17
+ - Achievement counts across models
18
+ - Performance metrics
19
+ - Cost analysis
20
+ """
21
+
22
+ import argparse
23
+ import asyncio
24
+ import json
25
+ import logging
26
+ import os
27
+ import sys
28
+ import time
29
+ from collections import defaultdict
30
+ from datetime import datetime
31
+ from pathlib import Path
32
+ from typing import Any
33
+ from uuid import uuid4
34
+
35
+ import numpy as np
36
+ import pandas as pd
37
+ from tqdm import tqdm
38
+ from tqdm.asyncio import tqdm_asyncio as atqdm
39
+
40
+ # Disable httpx logging for cleaner output
41
+ logging.getLogger("httpx").setLevel(logging.WARNING)
42
+
43
+ # Add parent directory to path for imports
44
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent.parent))
45
+
46
+ # Disable v1 logging to see v3 tracing clearly
47
+ os.environ["LANGFUSE_ENABLED"] = "false"
48
+ os.environ["SYNTH_LOGGING"] = "false"
49
+
50
+ # Import enhanced LM with v3 tracing
51
+ from synth_ai.lm.core.main_v3 import LM
52
+ from synth_ai.tracing_v3.abstractions import (
53
+ EnvironmentEvent,
54
+ RuntimeEvent,
55
+ SessionEventMarkovBlanketMessage,
56
+ TimeRecord,
57
+ )
58
+ from synth_ai.tracing_v3.decorators import set_turn_number
59
+
60
+ # Import session tracer for v3 tracing
61
+ from synth_ai.tracing_v3.session_tracer import SessionTracer
62
+
63
+ # from synth_ai.tracing_v3.utils import create_experiment_context # Not needed
64
+ from synth_ai.tracing_v3.turso.manager import AsyncSQLTraceManager
65
+
66
+ # Import Crafter hooks
67
+ try:
68
+ from synth_ai.environments.examples.crafter_classic.trace_hooks_v3 import CRAFTER_HOOKS
69
+ print(f"āœ… Loaded {len(CRAFTER_HOOKS.hooks)} Crafter achievement hooks (Easy, Medium, Hard)")
70
+ except ImportError:
71
+ print("Warning: Could not import CRAFTER_HOOKS for v3")
72
+ from synth_ai.tracing_v3.hooks import HookManager
73
+ CRAFTER_HOOKS = HookManager()
74
+
75
+ import random
76
+
77
+ import httpx
78
+
# Registry of every traced session started by this script:
# session_id -> (experiment_id, SessionTracer instance).
_SESSIONS: dict[str, tuple[str, object]] = {}

# Model names to compare.
MODELS_TO_TEST = ["gpt-4o-mini", "gpt-4.1-mini"]

# Crafter environment service endpoint (adjust to match your deployment).
CRAFTER_SERVICE_URL = "http://localhost:8901"

# The trace database location comes from the shared tracing_v3 config (matches serve.sh).
from synth_ai.tracing_v3.db_config import get_default_db_config

db_config = get_default_db_config()
DATABASE_URL = db_config.database_url

# HTTP retry policy for retry_http_request: attempt count, backoff base/cap and
# per-request timeout (all times in seconds).
MAX_RETRIES, BASE_DELAY, MAX_DELAY, HTTP_TIMEOUT = 3, 0.1, 2.0, 30.0
102
+
class ExperimentConfig:
    """Tunable settings for a multi-model Crafter experiment.

    Every field is a plain instance attribute initialised to a default below;
    attributes may be reassigned after construction.
    """

    def __init__(self):
        # Episode shape
        self.num_episodes = 10  # episodes run for each model
        self.max_turns = 100  # upper bound on agent turns per episode
        self.difficulty = "easy"
        self.base_seed = 1000  # episode i is seeded with base_seed + i

        # Time limits (seconds)
        self.turn_timeout = 30.0  # budget for one turn's LM response
        self.episode_timeout = 300.0  # wall-clock budget for a whole episode

        # Output / tracing
        self.save_traces = True
        self.verbose = True
        self.quiet = False  # verbose mode by default
        self.enable_v3_tracing = True
        self.v3_trace_dir = "./traces"

        # Endpoints, taken from the module-level defaults
        self.crafter_service_url = CRAFTER_SERVICE_URL
        self.database_url = DATABASE_URL
120
+
121
+
async def retry_http_request(client: httpx.AsyncClient, method: str, url: str, **kwargs) -> Any:
    """Send an HTTP request, retrying transient failures with exponential backoff.

    Responses with a status code below 500 (including 4xx) are returned to the
    caller as-is. 5xx responses, ``httpx.ConnectError``, ``httpx.ReadError`` and any
    other ``Exception`` raised by the client are retried, for at most
    ``MAX_RETRIES`` attempts in total. Exactly one delay is applied before each
    retry, chosen by the previous failure:

    * connection errors wait ``1s * 2**attempt``;
    * read errors wait ``min(1s * 2**attempt, 5s)``;
    * anything else waits ``min(BASE_DELAY * 2**attempt, MAX_DELAY)`` plus up to
      10% random jitter.

    Args:
        client: Open ``httpx.AsyncClient`` used for the request.
        method: HTTP method, e.g. ``"POST"``.
        url: Absolute request URL.
        **kwargs: Forwarded unchanged to ``client.request`` (``json=...`` etc.).

    Returns:
        The first response whose status code is below 500.

    Raises:
        Exception: the last error seen once every attempt has failed. For 5xx
        responses this is a plain ``Exception`` describing the status. If
        ``MAX_RETRIES`` permits no attempt at all, a ``RuntimeError`` is raised.
    """
    last_exception: Exception | None = None
    retry_delay = 0.0

    for attempt in range(MAX_RETRIES):
        if attempt > 0:
            # A single sleep per retry. Previously connection and read errors
            # slept in their handler *and* again here.
            await asyncio.sleep(retry_delay)

        try:
            response = await client.request(method, url, timeout=HTTP_TIMEOUT, **kwargs)
        except httpx.ConnectError as e:
            last_exception = Exception(f"Connection failed to {url}: {e}")
            retry_delay = 1.0 * (2 ** attempt)
            continue
        except httpx.ReadError as e:
            last_exception = e
            retry_delay = min(1.0 * (2 ** attempt), 5.0)
            continue
        except Exception as e:
            last_exception = e
        else:
            if response.status_code < 500:
                return response
            last_exception = Exception(f"HTTP {response.status_code}: {response.text}")

        # Default exponential backoff with a little jitter for 5xx / other errors.
        backoff = min(BASE_DELAY * (2 ** attempt), MAX_DELAY)
        retry_delay = backoff + random.uniform(0, 0.1 * backoff)

    if last_exception is None:
        # Only reachable when MAX_RETRIES <= 0; avoid `raise None` (a TypeError).
        last_exception = RuntimeError(
            f"No HTTP attempt made for {method} {url} (MAX_RETRIES={MAX_RETRIES})"
        )

    print(f" ❌ HTTP request failed after {MAX_RETRIES} attempts: {method} {url}")
    print(f" ❌ Error: {type(last_exception).__name__}: {str(last_exception)[:200]}")
    raise last_exception
156
+
157
+
# Crafter action name -> integer action ID sent to the Crafter service.
# Insertion order matters: the agent's system prompt lists actions in this order.
# NOTE(review): upstream Crafter defines 17 actions (IDs 0-16), where eating is done
# via "do"; confirm that the service actually accepts eat_cow (17) / eat_plant (18).
CRAFTER_ACTIONS = dict(
    noop=0,
    move_left=1,
    move_right=2,
    move_up=3,
    move_down=4,
    do=5,
    sleep=6,
    place_stone=7,
    place_table=8,
    place_furnace=9,
    place_plant=10,
    make_wood_pickaxe=11,
    make_stone_pickaxe=12,
    make_iron_pickaxe=13,
    make_wood_sword=14,
    make_stone_sword=15,
    make_iron_sword=16,
    eat_cow=17,
    eat_plant=18,
)

# Inverse lookup (action ID -> name), used to validate IDs proposed by the model.
INT_TO_ACTION_STRING = dict(zip(CRAFTER_ACTIONS.values(), CRAFTER_ACTIONS))
169
+
170
+
def compress_observation_for_trace(obs: dict[str, Any]) -> str:
    """Serialize a compact JSON summary of a Crafter observation for trace storage.

    The summary keeps the positive inventory entries (``inv``), the ``nearby`` list,
    health (``hp``), food and the number of truthy achievement flags (``ach``).
    Missing keys fall back to empty containers / 0.

    Returns:
        A compact JSON string (no whitespace). If the observation cannot be
        summarized (e.g. it is not a mapping, or an inventory count is not
        comparable to 0), the JSON object ``{"error": "<message>"}`` is returned
        instead of raising.
    """
    try:
        status = obs.get("status", {})
        return json.dumps(
            {
                "inv": {k: v for k, v in obs.get("inventory", {}).items() if v > 0},
                "nearby": obs.get("nearby", []),
                "hp": status.get("health", 0),
                "food": status.get("food", 0),
                "ach": sum(1 for v in obs.get("achievements_status", {}).values() if v),
            },
            separators=(",", ":"),
        )
    except Exception as e:
        # Let json.dumps escape the message. The previous hand-built f-string
        # produced invalid JSON whenever the message contained quotes or backslashes.
        return json.dumps({"error": str(e)})
183
+
184
+
def create_message(content: str, message_type: str, system_id: str, turn: int) -> SessionEventMarkovBlanketMessage:
    """Wrap ``content`` in a trace message tagged with its source system and turn.

    The turn number is stored both in the metadata and as the message's
    ``message_time``; ``event_time`` is the wall-clock time of this call.
    """
    timing = TimeRecord(event_time=time.time(), message_time=turn)
    tags = {"system_id": system_id, "turn": turn}
    return SessionEventMarkovBlanketMessage(
        content=content,
        message_type=message_type,
        metadata=tags,
        time_record=timing,
    )
196
+
197
+
# Upper bound (seconds) on one /step call, including retry_http_request's retries.
_ACTION_TIMEOUT_S = 5.0

# ASCII glyphs for the prompt's 5x5 map, keyed by semantic-map cell value.
# NOTE(review): assumes 0=grass, 1=tree, 2=stone, 3=cow, 4=water in the service's
# semantic-map encoding; confirm against the Crafter service.
_MAP_GLYPHS = {0: ".", 1: "T", 2: "S", 3: "C", 4: "W"}


def _count_achievements(obs: dict[str, Any]) -> int:
    """Return how many entries of ``obs["achievements_status"]`` are truthy."""
    return sum(1 for unlocked in obs.get("achievements_status", {}).values() if unlocked)


def _render_local_map(semantic_map: Any, position: Any) -> str:
    """Render a 5x5 ASCII window of ``semantic_map`` centred on ``position``.

    ``@`` marks the player, ``#`` is out of bounds and ``?`` is an unmapped cell.
    Any failure (bad position, empty map, ...) yields a short fallback notice.
    """
    try:
        px, py = position
        half = 5 // 2
        map_lines = []
        for dy in range(-half, half + 1):
            row = []
            for dx in range(-half, half + 1):
                x, y = px + dx, py + dy
                if dx == 0 and dy == 0:
                    row.append("@")
                elif 0 <= x < len(semantic_map) and 0 <= y < len(semantic_map[0]):
                    row.append(_MAP_GLYPHS.get(semantic_map[x][y], "?"))
                else:
                    row.append("#")
            map_lines.append(" ".join(row))
        return "\nMap (5x5 view, @ = you):\n" + "\n".join(map_lines)
    except Exception:
        return "\nMap view unavailable"


def _format_prompt(obs: dict[str, Any], turn: int) -> str:
    """Build the per-turn user prompt describing the current game state."""
    inventory_str = ", ".join(f"{k}: {v}" for k, v in obs.get("inventory", {}).items() if v > 0) or "empty"
    nearby_str = ", ".join(obs.get("nearby", [])) or "nothing"

    status = obs.get("status", {})
    health = status.get("health", 0)
    hunger = status.get("food", 0)

    position = obs.get("position", [0, 0])
    unlocked = [name for name, got in obs.get("achievements_status", {}).items() if got]
    achievements_str = ", ".join(unlocked) if unlocked else "none"

    semantic_map = obs.get("semantic_map", None)
    map_str = _render_local_map(semantic_map, position) if semantic_map is not None else ""

    return f"""Game State (Turn {turn}):
- Position: {position}
- Health: {health}/9
- Hunger: {hunger}/9
- Inventory: {inventory_str}
- Nearby objects: {nearby_str}
- Achievements unlocked: {achievements_str}
{map_str}

Choose your next actions based on what you see. Use the 'interact' tool with a list of action IDs.

Tips:
- Look at the map! T=tree (wood), S=stone, C=cow (food), W=water
- To collect resources: move to them (actions 1-4) then use action 5 (do)
- To craft: place table (8) first, then craft tools (11-16)
- If hungry and see cow (C), move to it and eat (17)

What actions do you want to take?"""


def _build_system_message() -> str:
    """Return the system prompt that explains the Crafter action-ID mapping."""
    action_list = "\n".join(f"{action_id}: {action}" for action, action_id in CRAFTER_ACTIONS.items())
    return f"""You are an agent playing Crafter, a 2D survival game. Your goal is to survive and unlock achievements.

You MUST use the 'interact' tool to execute actions. The tool takes a list of action IDs.

Action ID mapping:
{action_list}

Strategy tips:
- Start by collecting wood (move to trees and use action 5)
- Place a crafting table (action 8) to unlock crafting recipes
- Craft tools to collect resources more efficiently
- Eat when hungry, sleep when tired
- Explore to find different resources

IMPORTANT: Always use the 'interact' tool with a list of action IDs. For example: interact(actions=[2, 2, 5]) to move right twice and collect."""


def _make_interact_tool():
    """Build the ``interact`` tool the agent calls with a list of action IDs."""
    from pydantic import BaseModel, Field

    from synth_ai.lm.tools.base import BaseTool

    class InteractArgs(BaseModel):
        actions: list[int] = Field(..., description="List of action IDs to execute")

    return BaseTool(
        name="interact",
        arguments=InteractArgs,
        description="Execute actions in the Crafter game",
    )


def _parse_interact_actions(tool_call: dict[str, Any]):
    """Return the raw ``actions`` list of an ``interact`` call, or None if malformed.

    Arguments may arrive either as a JSON string or as an already-decoded dict.
    """
    raw_args = tool_call.get("function", {}).get("arguments", "{}")
    if isinstance(raw_args, str):
        try:
            raw_args = json.loads(raw_args)
        except json.JSONDecodeError:
            return None
    if not isinstance(raw_args, dict):
        return None
    actions = raw_args.get("actions", [])
    return actions if isinstance(actions, list) else None


def _validate_action(action_id: Any) -> tuple[int, bool]:
    """Map a model-proposed action to ``(legal_action_id, was_valid)``; illegal -> noop."""
    if isinstance(action_id, int) and not isinstance(action_id, bool) and action_id in INT_TO_ACTION_STRING:
        return action_id, True
    return 0, False


def _plan_actions(action_response: Any) -> list[tuple[int, bool]]:
    """Turn an LM response into the ordered ``(action_id, was_valid)`` pairs to execute.

    Every ``interact`` tool call contributes its actions in order. Illegal IDs and
    malformed arguments become a noop (0) flagged invalid. An ``interact`` call with
    an empty action list becomes a single valid noop. A response with no
    ``interact`` call at all yields a single invalid noop.
    """
    tool_calls = getattr(action_response, "tool_calls", None) or []
    planned: list[tuple[int, bool]] = []
    saw_interact = False
    for tool_call in tool_calls:
        if tool_call.get("function", {}).get("name") != "interact":
            continue
        saw_interact = True
        actions = _parse_interact_actions(tool_call)
        if actions is None:
            planned.append((0, False))
        elif not actions:
            planned.append((0, True))
        else:
            planned.extend(_validate_action(a) for a in actions)
    if not saw_interact:
        planned.append((0, False))
    return planned


def _extract_instance_id(init_data: dict[str, Any]) -> Any:
    """Pick the environment ID out of an /initialize response (several key names accepted)."""
    for key in ("instance_id", "env_id", "id"):
        if key in init_data:
            return init_data[key]
    print(f"❌ Unexpected response format from Crafter service: {init_data}")
    raise KeyError(f"Could not find environment ID in response. Available keys: {list(init_data.keys())}")


async def _record_observation(session_tracer: Any, obs: dict[str, Any], instance_id: Any, turn: int) -> None:
    """Record the compressed observation as a system message for ``turn``."""
    obs_msg = create_message(
        f"Observation: {compress_observation_for_trace(obs)}",
        "system",
        f"crafter_env_{instance_id}",
        turn,
    )
    await session_tracer.record_message(
        content=obs_msg.content,
        message_type=obs_msg.message_type,
        event_time=obs_msg.time_record.event_time,
        message_time=obs_msg.time_record.message_time,
        metadata=obs_msg.metadata,
    )


async def _execute_action(client, config, session_tracer, instance_id, episode_num, turn, action_id, valid, obs):
    """Send one action to the Crafter service and record its runtime/environment events.

    ``obs`` is the observation *before* this action; it becomes the event's
    ``system_state_before``.

    Returns:
        ``(new_obs, done)`` on success, or None if the step timed out or the service
        answered with a non-200 status (the caller should end the episode).
    """
    try:
        step_response = await asyncio.wait_for(
            retry_http_request(
                client,
                "POST",
                f"{config.crafter_service_url}/env/CrafterClassic/step",
                json={
                    "env_id": instance_id,
                    "action": {"tool_calls": [{"tool": "interact", "args": {"action": action_id}}]},
                },
            ),
            timeout=_ACTION_TIMEOUT_S,
        )
    except asyncio.TimeoutError:
        print(f" ⏰ Action execution timed out in episode {episode_num}")
        return None

    if step_response.status_code != 200:
        print(f" ❌ Step failed: {step_response.status_code} - {step_response.text}")
        return None

    step_data = step_response.json()
    new_obs = step_data["observation"]
    reward = step_data["reward"]
    done = step_data["done"]

    system_instance_id = f"crafter_env_{instance_id}"
    await session_tracer.record_event(
        RuntimeEvent(
            system_instance_id=system_instance_id,
            time_record=TimeRecord(event_time=time.time(), message_time=turn),
            actions=[action_id],
            metadata={
                "action_name": INT_TO_ACTION_STRING.get(action_id, "unknown"),
                # Whether the model's proposed action was legal (illegal ones run as noop).
                "valid": valid,
            },
        )
    )
    await session_tracer.record_event(
        EnvironmentEvent(
            system_instance_id=system_instance_id,
            time_record=TimeRecord(event_time=time.time(), message_time=turn),
            reward=reward,
            terminated=done,
            system_state_before={"observation": obs},
            system_state_after={
                "observation": new_obs,
                "public_state": {"achievements_status": new_obs.get("achievements_status", {})},
            },
        )
    )
    return new_obs, done


async def _terminate_env(client, config, instance_id) -> None:
    """Best-effort /terminate of a Crafter instance; failures are reported, not raised."""
    try:
        await retry_http_request(
            client, "POST", f"{config.crafter_service_url}/env/CrafterClassic/terminate",
            json={"env_id": instance_id},
        )
    except Exception as e:
        print(f" ⚠️ Failed to terminate environment: {e}")


async def run_episode(config: ExperimentConfig,
                      model_name: str,
                      episode_num: int,
                      experiment_id: str) -> dict[str, Any]:
    """Play one Crafter episode with ``model_name`` and record it with v3 tracing.

    The environment is seeded with ``config.base_seed + episode_num``. Each turn,
    the model receives a state prompt and answers with ``interact`` tool calls,
    whose actions are stepped one by one. The episode stops on ``done``, on a step
    failure, or on a turn/episode timeout.

    Per-turn failures (LM errors, step errors, turn timeouts) end the episode but
    still return normal statistics. Failures outside a turn (initialization,
    tracing) return a failure dict with an ``error`` key. In every case the Crafter
    instance is terminated (if it was created) and the tracing session is ended
    and closed. If ``config`` carries a ``_pbar`` progress bar, this episode
    advances it by exactly ``config.max_turns`` steps.

    Returns:
        A dict with keys model, episode, total_achievements, achievements,
        invalid_action_rate, total_actions, invalid_actions and session_id
        (plus ``error`` on failure).
    """
    session_tracer = SessionTracer(hooks=CRAFTER_HOOKS, db_url=config.database_url)
    session_id = await session_tracer.start_session(
        metadata={
            "model": model_name,
            "episode": episode_num,
            "experiment_id": experiment_id,
            "difficulty": config.difficulty,
        }
    )
    _SESSIONS[session_id] = (experiment_id, session_tracer)

    pbar = getattr(config, "_pbar", None)
    instance_id = None
    turns_played = 0
    total_actions = 0
    invalid_actions = 0

    async with httpx.AsyncClient() as client:
        try:
            lm = LM(
                vendor="openai",
                model=model_name,
                temperature=0.1,  # low temperature for more consistent gameplay
                session_tracer=session_tracer,
                system_id=f"crafter_agent_{model_name}",
                enable_v3_tracing=True,
            )
            # Loop-invariant: build the tool schema and system prompt once per episode.
            interact_tool = _make_interact_tool()
            system_message = _build_system_message()

            seed = config.base_seed + episode_num  # consecutive seeds across episodes
            init_response = await retry_http_request(
                client, "POST", f"{config.crafter_service_url}/env/CrafterClassic/initialize",
                json={"config": {"difficulty": config.difficulty, "seed": seed}},
            )
            init_data = init_response.json()
            instance_id = _extract_instance_id(init_data)
            obs = init_data["observation"]

            done = False
            episode_start_time = time.time()
            for turn in range(config.max_turns):
                if done:
                    break
                if time.time() - episode_start_time > config.episode_timeout:
                    print(f" ⏰ Episode {episode_num} timed out after {config.episode_timeout}s")
                    break

                if pbar is not None:
                    pbar.set_postfix({
                        f"ep{episode_num}": f"step {turn+1}/{config.max_turns}, ach: {_count_achievements(obs)}"
                    })

                set_turn_number(turn)
                await session_tracer.start_timestep(f"turn_{turn}")
                try:
                    prompt = _format_prompt(obs, turn)
                    await _record_observation(session_tracer, obs, instance_id, turn)
                    try:
                        action_response = await asyncio.wait_for(
                            lm.respond_async(
                                system_message=system_message,
                                user_message=prompt,
                                tools=[interact_tool],
                                turn_number=turn,
                            ),
                            timeout=config.turn_timeout,
                        )
                        for action_id, valid in _plan_actions(action_response):
                            total_actions += 1
                            if not valid:
                                invalid_actions += 1
                            step = await _execute_action(
                                client, config, session_tracer, instance_id,
                                episode_num, turn, action_id, valid, obs,
                            )
                            if step is None:
                                done = True
                                break
                            obs, done = step
                            if done:
                                # Stop mid-plan: the episode has ended.
                                break
                    except asyncio.TimeoutError:
                        print(f" ⏰ Turn {turn} timed out for episode {episode_num} after {config.turn_timeout}s")
                        done = True
                    except Exception as e:
                        print(f" ❌ Environment step error: {e}")
                        done = True
                finally:
                    # Always close the timestep, even on timeouts or errors.
                    await session_tracer.end_timestep(f"turn_{turn}")
                turns_played += 1
                if pbar is not None:
                    pbar.update(1)

            final_achievements = obs.get("achievements_status", {})
            result = {
                "model": model_name,
                "episode": episode_num,
                "total_achievements": _count_achievements(obs),
                "achievements": final_achievements,
                "invalid_action_rate": invalid_actions / total_actions if total_actions > 0 else 0,
                "total_actions": total_actions,
                "invalid_actions": invalid_actions,
                "session_id": session_id,
            }

        except Exception as e:
            print(f" ❌ Episode failed: {e}")
            import traceback
            traceback.print_exc()
            result = {
                "model": model_name,
                "episode": episode_num,
                "total_achievements": 0,
                "achievements": {},
                "invalid_action_rate": 1.0,
                "total_actions": 0,
                "invalid_actions": 0,
                "session_id": session_id,
                "error": str(e),
            }

        finally:
            # Account for turns that were never played so the shared bar reaches its total.
            if pbar is not None and turns_played < config.max_turns:
                pbar.update(config.max_turns - turns_played)
            if instance_id is not None:
                await _terminate_env(client, config, instance_id)
            try:
                await session_tracer.end_session(save=config.save_traces)
            finally:
                await session_tracer.close()

    return result
637
+
638
+
639
async def run_model_experiment(config: ExperimentConfig, model_name: str, experiment_id: str) -> list[dict[str, Any]]:
    """Run ``config.num_episodes`` episodes for a single model concurrently.

    A single progress bar covering ``num_episodes * max_turns`` steps is
    stored on ``config._pbar`` for the duration of the run so that each
    ``run_episode`` call can advance it. The attribute is removed again when
    the run finishes, so nothing can update the bar after it is closed.

    Args:
        config: Experiment settings (``num_episodes`` and ``max_turns`` are read here).
        model_name: Model identifier passed through to every episode.
        experiment_id: Identifier shared by all episodes of this experiment.

    Returns:
        One result dict per episode, in episode order. An episode that raised
        instead of returning is represented by an error dict of the same shape
        that ``run_episode`` produces for its own failures, with an ``"error"`` key.

    Raises:
        BaseException: Non-``Exception`` outcomes (e.g. ``asyncio.CancelledError``)
            are re-raised rather than recorded.
    """
    print(f"\nšŸš€ Running {config.num_episodes} episodes for {model_name} in parallel...\n")

    # Create a progress bar for all steps across all episodes
    total_steps = config.num_episodes * config.max_turns
    pbar = atqdm(total=total_steps, desc=f"{model_name}", unit="steps", leave=True)
    config._pbar = pbar  # Store in config so episodes can update it

    try:
        # Each episode creates its own tracer.
        coros = [run_episode(config, model_name, i, experiment_id) for i in range(config.num_episodes)]

        # return_exceptions=True: one failing episode must not discard the
        # others' results, nor leave them running after the bar is closed.
        outcomes = await asyncio.gather(*coros, return_exceptions=True)

        results: list[dict[str, Any]] = []
        for episode_num, outcome in enumerate(outcomes):
            if isinstance(outcome, Exception):
                # run_episode normally turns its own failures into an error dict.
                # This covers anything that escaped it, e.g. a failure while
                # ending/closing the tracer inside its except handler.
                print(f"  Episode {episode_num} for {model_name} raised {type(outcome).__name__}: {outcome}")
                results.append({
                    "model": model_name,
                    "episode": episode_num,
                    "total_achievements": 0,
                    "achievements": {},
                    "invalid_action_rate": 1.0,
                    "total_actions": 0,
                    "invalid_actions": 0,
                    "session_id": None,
                    "error": str(outcome),
                })
            elif isinstance(outcome, BaseException):
                # Cancellation and interpreter-exit signals must still propagate.
                raise outcome
            else:
                results.append(outcome)

        # Calculate summary stats
        successful_results = [r for r in results if "error" not in r]
        if successful_results:
            avg_achievements = sum(r["total_achievements"] for r in successful_results) / len(successful_results)
            avg_invalid_rate = sum(r["invalid_action_rate"] for r in successful_results) / len(successful_results)
            pbar.set_postfix({
                "avg_achievements": f"{avg_achievements:.1f}",
                "avg_invalid_rate": f"{avg_invalid_rate:.1%}",
                "success_rate": f"{len(successful_results)}/{len(results)}"
            })
    finally:
        pbar.close()
        # Drop the reference to the closed bar; run_episode guards on hasattr().
        if getattr(config, "_pbar", None) is pbar:
            del config._pbar

    return results
+
673
+
674
async def analyze_results(config: ExperimentConfig, all_results: dict[str, list[dict[str, Any]]]):
    """Print a cross-model comparison of episode results and export them to JSON.

    Args:
        config: Experiment settings. ``database_url``, ``num_episodes``,
            ``max_turns`` and ``difficulty`` are read here.
        all_results: Mapping of model name to its per-episode result dicts.
            Dicts containing an ``"error"`` key are treated as failed episodes
            and excluded from every statistic.

    Side effects:
        - Prints summary tables to stdout.
        - Queries model usage from the trace database. This is best effort:
          a database failure is reported and does not prevent the export.
        - Writes ``crafter_experiment_results_<timestamp>.json`` to the current
          working directory.
    """
    print("\nšŸ“Š Analysis Results:")
    print("=" * 80)

    # Successful episodes per model, computed once and reused by every table.
    valid_by_model = {
        model: [r for r in results if "error" not in r]
        for model, results in all_results.items()
    }

    # Basic statistics by model (models with no successful episode are omitted).
    model_stats = {}
    for model, results in all_results.items():
        valid_results = valid_by_model[model]
        if not valid_results:
            continue
        achievements = [r["total_achievements"] for r in valid_results]
        invalid_rates = [r["invalid_action_rate"] for r in valid_results]
        model_stats[model] = {
            # float() keeps the values plain Python floats for json.dump.
            "avg_achievements": float(np.mean(achievements)),
            "std_achievements": float(np.std(achievements)),
            "max_achievements": max(achievements),
            "avg_invalid_rate": float(np.mean(invalid_rates)),
            "success_rate": len(valid_results) / len(results),
        }

    # Model comparison: every cell is padded to its header's width so columns line up.
    print("\nšŸ“ˆ Model Performance Summary:")
    summary_header = (
        f"{'Model':<20} {'Avg Achievements':<18} {'Max Achievements':<18} "
        f"{'Invalid Rate':<15} {'Success Rate':<15}"
    )
    print(summary_header)
    print("-" * len(summary_header))
    for model, stats in sorted(model_stats.items(), key=lambda item: item[1]["avg_achievements"], reverse=True):
        avg_cell = f"{stats['avg_achievements']:.2f} ± {stats['std_achievements']:.2f}"
        print(
            f"{model:<20} {avg_cell:<18} {stats['max_achievements']:<18} "
            f"{stats['avg_invalid_rate']:<15.2%} {stats['success_rate']:<15.2%}"
        )

    # Achievement frequency analysis
    print("\nšŸ† Achievement Frequencies:")
    achievement_counts = defaultdict(lambda: defaultdict(int))
    for model, valid_results in valid_by_model.items():
        for result in valid_results:
            for achievement, unlocked in result["achievements"].items():
                if unlocked:
                    achievement_counts[model][achievement] += 1

    all_achievements = set()
    for model_achievements in achievement_counts.values():
        all_achievements.update(model_achievements.keys())

    if all_achievements:
        models_sorted = sorted(all_results.keys())
        # Width of one row cell " cnt/tot  (pct%)" without its leading space:
        # 3 + 1 + 3 + 2 + 3 + 2 = 14. The header cells use the same width.
        cell_width = 14
        print(f"\n{'Achievement':<25}" + "".join(f" {model[:cell_width]:>{cell_width}}" for model in models_sorted))
        print("-" * (25 + (cell_width + 1) * len(models_sorted)))

        for achievement in sorted(all_achievements):
            row = f"{achievement:<25}"
            for model in models_sorted:
                count = achievement_counts[model].get(achievement, 0)
                total = len(valid_by_model[model])
                pct = (count / total * 100) if total > 0 else 0
                row += f" {count:>3}/{total:<3} ({pct:>3.0f}%)"
            print(row)

    # Query model usage from the database, filtered to models used in this experiment.
    print("\nšŸ’° Model Usage Statistics from Current Experiment:")
    model_usage_df = None
    try:
        db_manager = AsyncSQLTraceManager(config.database_url)
        await db_manager.initialize()
        try:
            model_usage_df = await db_manager.get_model_usage()
        finally:
            await db_manager.close()
    except Exception as e:
        # Usage stats are informational; don't let a DB problem block the export below.
        print(f"  Could not load model usage statistics: {e}")

    if model_usage_df is not None and not model_usage_df.empty:
        experiment_models = set(all_results.keys())
        filtered_df = model_usage_df[model_usage_df['model_name'].isin(experiment_models)]

        if not filtered_df.empty:
            usage_header = (
                f"{'Model':<20} {'Provider':<10} {'Usage Count':<12} "
                f"{'Avg Latency (ms)':<18} {'Total Cost':<12}"
            )
            print(usage_header)
            print("-" * len(usage_header))
            for _, row in filtered_df.iterrows():
                # Missing values may arrive as None or NaN; show both as N/A.
                provider = row['provider']
                provider_cell = provider if pd.notna(provider) and provider else 'N/A'
                avg_latency = row['avg_latency_ms']
                latency_cell = f"{avg_latency:<18.2f}" if pd.notna(avg_latency) else f"{'N/A':<18}"
                cost = row['total_cost_usd']
                cost_cell = f"${cost:<11.4f}" if pd.notna(cost) else f"{'N/A':<12}"
                print(f"{row['model_name']:<20} {provider_cell:<10} {row['usage_count']:<12} "
                      f"{latency_cell} {cost_cell}")

    # Export detailed results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = f"crafter_experiment_results_{timestamp}.json"

    with open(results_file, "w", encoding="utf-8") as f:
        json.dump({
            "config": {
                "num_episodes": config.num_episodes,
                "max_turns": config.max_turns,
                "difficulty": config.difficulty,
                "models": list(all_results.keys())
            },
            "results": all_results,
            "statistics": model_stats,
            "timestamp": timestamp
        }, f, indent=2)

    print(f"\nšŸ’¾ Detailed results saved to: {results_file}")
+
783
+
784
async def main():
    """Main entry point: parse arguments, check the Crafter service, run and analyze.

    Runs every requested model in turn via ``run_model_experiment`` and then
    compares them with ``analyze_results``.

    Raises:
        SystemExit: Status 2 (argparse) for invalid arguments. Status 1 when
            the Crafter service cannot be reached or its ``/health`` endpoint
            does not return 200, so shell callers can detect the failure.
    """
    parser = argparse.ArgumentParser(description="Run Crafter experiments with multiple models")
    parser.add_argument("--episodes", type=int, default=5, help="Number of episodes per model")
    parser.add_argument("--max-turns", type=int, default=100, help="Maximum turns per episode")
    parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy", help="Game difficulty")
    parser.add_argument("--models", nargs="+", default=MODELS_TO_TEST, help="Models to test")
    parser.add_argument("--no-save", action="store_true", help="Don't save traces to database")
    parser.add_argument("--quiet", action="store_true", help="Reduce output verbosity")
    parser.add_argument("--db-url", default=DATABASE_URL, help="Database URL for tracing")
    parser.add_argument("--base-seed", type=int, default=1000, help="Base seed for episodes (episodes use base_seed+episode_num)")
    parser.add_argument("--turn-timeout", type=float, default=30.0, help="Timeout per turn in seconds")
    parser.add_argument("--episode-timeout", type=float, default=300.0, help="Total timeout per episode in seconds")

    args = parser.parse_args()

    # Zero or negative counts would give an empty run and a meaningless seed
    # range below. The timeouts are durations in seconds and must be positive.
    if args.episodes < 1:
        parser.error("--episodes must be at least 1")
    if args.max_turns < 1:
        parser.error("--max-turns must be at least 1")
    if args.turn_timeout <= 0:
        parser.error("--turn-timeout must be positive")
    if args.episode_timeout <= 0:
        parser.error("--episode-timeout must be positive")

    # Create configuration
    config = ExperimentConfig()
    config.num_episodes = args.episodes
    config.max_turns = args.max_turns
    config.difficulty = args.difficulty
    config.save_traces = not args.no_save
    config.verbose = not args.quiet
    config.quiet = args.quiet
    config.database_url = args.db_url
    config.base_seed = args.base_seed
    config.turn_timeout = args.turn_timeout
    config.episode_timeout = args.episode_timeout

    # Generate experiment ID
    experiment_id = f"crafter_multi_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    print("šŸŽ® Crafter Multi-Model Experiment")
    print("=" * 50)
    print(f"Experiment ID: {experiment_id}")
    print(f"Models: {', '.join(args.models)}")
    print(f"Episodes per model: {config.num_episodes}")
    print(f"Max turns per episode: {config.max_turns}")
    print(f"Difficulty: {config.difficulty}")
    print(f"Seeds: {config.base_seed} to {config.base_seed + config.num_episodes - 1}")
    print(f"Turn timeout: {config.turn_timeout}s")
    print(f"Episode timeout: {config.episode_timeout}s")
    print(f"Save traces: {config.save_traces}")
    print(f"Database URL: {config.database_url}")
    print("=" * 50)

    # Check Crafter service; a failed check exits non-zero instead of returning normally.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{config.crafter_service_url}/health", timeout=5.0)
    except Exception as e:
        print(f"āŒ Cannot connect to Crafter service at {config.crafter_service_url}: {e}")
        print("Please ensure the Crafter service is running.")
        raise SystemExit(1) from e
    if response.status_code != 200:
        print(f"āŒ Crafter service not healthy at {config.crafter_service_url}")
        raise SystemExit(1)

    print("āœ… Crafter service is running")

    # Run experiments for each model, one model at a time
    all_results = {}
    for model in args.models:
        all_results[model] = await run_model_experiment(config, model, experiment_id)

    # Analyze and compare results
    await analyze_results(config, all_results)

    print("\nāœ… Experiment complete!")
+
856
+
857
if __name__ == "__main__":
    # Script entry point: asyncio.run creates the event loop, drives the
    # coroutine to completion and closes the loop afterwards.
    asyncio.run(main=main())