synth-ai 0.2.4.dev8__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (112) hide show
  1. synth_ai/__init__.py +1 -1
  2. synth_ai/cli/__init__.py +6 -0
  3. synth_ai/cli/demo.py +68 -9
  4. synth_ai/cli/rl_demo.py +137 -0
  5. synth_ai/cli/root.py +65 -0
  6. synth_ai/demos/core/__init__.py +1 -0
  7. synth_ai/demos/core/cli.py +685 -0
  8. synth_ai/demos/demo_task_apps/__init__.py +1 -0
  9. synth_ai/demos/demo_task_apps/core.py +374 -0
  10. synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
  11. synth_ai/demos/demo_task_apps/math/app.py +37 -0
  12. synth_ai/demos/demo_task_apps/math/config.toml +44 -0
  13. synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
  14. synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
  15. synth_ai/environments/examples/bandit/__init__.py +33 -0
  16. synth_ai/environments/examples/bandit/engine.py +294 -0
  17. synth_ai/environments/examples/bandit/environment.py +194 -0
  18. synth_ai/environments/examples/bandit/taskset.py +200 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
  25. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
  26. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
  27. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
  28. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
  29. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
  30. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
  31. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
  32. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
  33. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
  34. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
  35. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
  36. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
  37. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
  38. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
  39. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
  40. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
  41. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
  42. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
  43. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
  44. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
  45. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
  46. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
  47. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
  48. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
  49. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
  50. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
  51. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
  52. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
  53. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
  54. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
  55. synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
  56. synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
  57. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
  58. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
  59. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
  60. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
  61. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
  62. synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
  63. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
  64. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
  65. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
  66. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
  67. synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
  68. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
  69. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
  70. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
  71. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
  72. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
  73. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
  74. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
  75. synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
  76. synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
  77. synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
  78. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
  79. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
  80. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
  81. synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
  82. synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
  83. synth_ai/environments/examples/crafter_classic/environment.py +41 -2
  84. synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
  85. synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
  86. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
  87. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
  88. synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
  89. synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
  90. synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
  91. synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
  92. synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
  93. synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
  94. synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
  95. synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
  96. synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
  97. synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
  98. synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
  99. synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
  100. synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
  101. synth_ai/environments/examples/red/units/__init__.py +1 -0
  102. synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
  103. synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
  104. synth_ai/environments/service/app.py +8 -0
  105. synth_ai/install_sqld.sh +40 -0
  106. synth_ai-0.2.5.dist-info/METADATA +106 -0
  107. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/RECORD +111 -12
  108. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/entry_points.txt +1 -0
  109. synth_ai-0.2.4.dev8.dist-info/METADATA +0 -635
  110. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/WHEEL +0 -0
  111. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/licenses/LICENSE +0 -0
  112. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,655 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Run Crafter rollouts with Qwen models and display results in a table format
4
+ """
5
+
6
+ import asyncio
7
+ import json
8
+ import uuid
9
+ import argparse
10
+ import logging
11
+ import time
12
+ from datetime import datetime
13
+ from typing import Dict, Any, Optional, List, Tuple
14
+ from pydantic import BaseModel
15
+ import httpx
16
+ import os
17
+ from pathlib import Path
18
+ import numpy as np
19
+ from rich.console import Console
20
+ from rich.table import Table
21
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeRemainingColumn
22
+ from rich.live import Live
23
+ from rich.layout import Layout
24
+ from rich.panel import Panel
25
+ from collections import defaultdict
26
+
27
# Switch Langfuse off for this script; the key values are placeholders.
os.environ.update(
    {
        "LANGFUSE_ENABLED": "false",
        "LANGFUSE_PUBLIC_KEY": "dummy",
        "LANGFUSE_SECRET_KEY": "dummy",
    }
)

# Crafter trace hooks are optional: fall back to an empty list when the
# hooks module cannot be imported.
try:
    from synth_ai.environments.examples.crafter_classic.trace_hooks import CRAFTER_HOOKS
except ImportError:
    CRAFTER_HOOKS = []

# Service configuration: base URL of the LLM service and the API key, which
# is read from the MODAL_API_KEY environment variable with a test-key default.
MODAL_BASE_URL = "https://synth-laboratories--unified-ft-service-fastapi-app.modal.run"
MODAL_API_KEY = os.getenv("MODAL_API_KEY", "sk-test-11111111111111111111111111111111")
41
+
42
# Maps the parameter-count tag found in a model name to the size category
# used for routing on the Modal service.
MODEL_SIZE_ROUTING = {
    "0.5B": "small",
    "1.5B": "small",
    "3B": "small",
    "7B": "medium",
    "14B": "medium",
    "32B": "large32",
    "72B": "large72"
}


def get_model_size_category(model_name: str) -> str:
    """Return the routing size category for ``model_name``.

    Size tags are tried in table order. A tag matches when it appears as
    ``-<tag>-`` inside the name or as a ``-<tag>`` suffix. When no tag
    matches, the category defaults to ``"medium"``.
    """
    matching_categories = (
        category
        for size, category in MODEL_SIZE_ROUTING.items()
        if model_name.endswith(f"-{size}") or f"-{size}-" in model_name
    )
    return next(matching_categories, "medium")
59
+
60
# HTTP retry configuration shared by the request helpers in this script.
MAX_RETRIES = 3        # attempts made per request
BASE_DELAY = 0.1       # seconds waited before the first retry (doubles after)
MAX_DELAY = 2.0        # intended cap on the backoff delay, in seconds
HTTP_TIMEOUT = 120.0   # per-request timeout, in seconds

# Shared rich console used for all terminal output.
console = Console()
67
+
68
class RolloutConfig(BaseModel):
    """Settings for one batch of Crafter rollouts.

    Groups the model sampling parameters, the evaluation size, the service
    endpoints, and the display/output options.
    """

    # --- Model ---
    model_name: str = "Qwen/Qwen2.5-7B-Instruct"
    temperature: float = 0.7
    max_tokens: int = 512

    # --- Evaluation ---
    num_episodes: int = 10
    max_steps_per_episode: int = 100
    difficulty: str = "easy"
    seed: Optional[int] = None

    # --- Services ---
    crafter_url: str = "http://localhost:8901"
    llm_base_url: str = MODAL_BASE_URL
    llm_api_key: str = MODAL_API_KEY

    # --- Display / output ---
    show_live_progress: bool = True
    save_results: bool = True
    output_file: Optional[str] = None
90
+
91
+
92
class EpisodeStats:
    """Mutable record of what happened during a single episode.

    Counters start at zero or empty, and ``start_time`` is stamped on
    construction. ``end_time`` and ``termination_reason`` stay ``None``
    until the caller fills them in.
    """

    def __init__(self, episode_id: str):
        self.episode_id = episode_id
        self.steps = 0
        self.total_reward = 0.0
        self.achievements = []
        self.final_health = 0
        self.final_hunger = 0
        self.final_thirst = 0
        self.resources_collected = defaultdict(int)
        self.actions_taken = defaultdict(int)
        self.start_time = time.time()
        self.end_time = None
        self.termination_reason = None
        self.llm_response_times = []

    def duration(self) -> float:
        """Return elapsed seconds: to ``end_time`` if set, else to now."""
        # Compare against None explicitly so an end_time of 0.0 is honoured
        # instead of being mistaken for "not finished yet".
        if self.end_time is not None:
            return self.end_time - self.start_time
        return time.time() - self.start_time

    def avg_response_time(self) -> float:
        """Return the mean LLM response time, or 0.0 when none were recorded."""
        if not self.llm_response_times:
            return 0.0
        # Convert the numpy scalar so the declared ``float`` return type holds.
        return float(np.mean(self.llm_response_times))
118
+
119
+
120
async def retry_http_request(client: httpx.AsyncClient, method: str, url: str, **kwargs) -> Any:
    """Send an HTTP request, retrying on exceptions and 5xx responses.

    Up to ``MAX_RETRIES`` attempts are made. Before each retry the coroutine
    sleeps for ``BASE_DELAY * 2 ** (attempt - 1)`` seconds, capped at
    ``MAX_DELAY``.

    Args:
        client: The ``httpx.AsyncClient`` used to send the request.
        method: HTTP method name passed to ``client.request``.
        url: Target URL.
        **kwargs: Extra arguments forwarded to ``client.request``. A
            ``timeout`` given here overrides the ``HTTP_TIMEOUT`` default.

    Returns:
        The first response whose status code is below 500.

    Raises:
        Exception: The error from the final attempt if it raised, or a
            generic ``Exception`` when every attempt returned a 5xx status.
    """
    # setdefault avoids a duplicate-keyword TypeError when the caller passes
    # its own timeout.
    kwargs.setdefault("timeout", HTTP_TIMEOUT)
    last_status = None

    for attempt in range(MAX_RETRIES):
        if attempt > 0:
            backoff = min(BASE_DELAY * (2 ** (attempt - 1)), MAX_DELAY)
            await asyncio.sleep(backoff)

        try:
            response = await client.request(method, url, **kwargs)
        except Exception:
            # Only the final failure propagates; earlier ones are retried.
            if attempt == MAX_RETRIES - 1:
                raise
            continue

        if response.status_code < 500:
            return response
        last_status = response.status_code

    # Reached only when the last attempt returned a 5xx status.
    detail = f" (last status: {last_status})" if last_status is not None else ""
    raise Exception(f"Failed after {MAX_RETRIES} attempts{detail}")
137
+
138
+
139
async def warmup_model(config: RolloutConfig, max_attempts: int = 30) -> bool:
    """Poll the LLM service until the configured model answers a test prompt.

    A best-effort POST is first sent to ``/warmup/<model_name>``. Then up to
    ``max_attempts`` small chat-completion requests are issued, five seconds
    apart, until one returns status 200 with a non-empty ``choices`` list.

    Args:
        config: Supplies the model name, base URL and API key.
        max_attempts: Maximum number of polling requests.

    Returns:
        True as soon as a polling request succeeds, False if none do.
    """
    console.print(f"[yellow]Warming up {config.model_name}...[/yellow]")

    async with httpx.AsyncClient() as client:
        headers = {
            "Authorization": f"Bearer {config.llm_api_key}",
            "Content-Type": "application/json"
        }

        # Best-effort: the warmup endpoint might not exist, so failures are
        # ignored. Catch Exception rather than using a bare except so that
        # task cancellation and Ctrl-C still propagate.
        try:
            warmup_url = f"{config.llm_base_url}/warmup/{config.model_name}"
            response = await client.post(warmup_url, headers=headers, timeout=30.0)
            if response.status_code == 200:
                console.print("[green]✓ Model warmup endpoint called[/green]")
        except Exception:
            pass

        # Now poll with actual inference requests
        test_messages = [
            {"role": "user", "content": "Say 'ready' if you're loaded."}
        ]

        for attempt in range(max_attempts):
            try:
                start_time = time.time()
                response = await client.post(
                    f"{config.llm_base_url}/v1/chat/completions",
                    headers=headers,
                    json={
                        "model": config.model_name,
                        "messages": test_messages,
                        "temperature": 0.1,
                        "max_tokens": 10,
                    },
                    timeout=120.0
                )
                elapsed = time.time() - start_time

                if response.status_code == 200:
                    data = response.json()
                    if "choices" in data and data["choices"]:
                        console.print(f"[green]✓ Model ready! (response time: {elapsed:.1f}s)[/green]")
                        return True

                # Not ready yet; only report attempts that were slow.
                if elapsed > 10:
                    console.print(f"[yellow]Model is loading... attempt {attempt + 1}/{max_attempts} (took {elapsed:.1f}s)[/yellow]")

            except httpx.TimeoutException:
                console.print(f"[yellow]Timeout waiting for model... attempt {attempt + 1}/{max_attempts}[/yellow]")
            except Exception as e:
                console.print(f"[yellow]Error during warmup: {str(e)[:100]}[/yellow]")

            # Wait before retrying; there is nothing to wait for after the
            # final attempt.
            if attempt < max_attempts - 1:
                await asyncio.sleep(5)

    console.print(f"[red]Failed to warmup model after {max_attempts} attempts[/red]")
    return False
200
+
201
+
202
async def call_llm(messages: List[Dict[str, str]], config: RolloutConfig) -> Tuple[str, float]:
    """Request one chat completion and time it.

    Args:
        messages: Chat messages sent as the request's ``messages`` field.
        config: Supplies the endpoint, API key, model name and sampling
            parameters.

    Returns:
        A ``(content, elapsed_seconds)`` pair: the content of the first
        choice's message and the wall-clock duration of the request.

    Raises:
        Exception: If the response status code is not 200.
    """
    endpoint = f"{config.llm_base_url}/v1/chat/completions"
    request_headers = {
        "Authorization": f"Bearer {config.llm_api_key}",
        "Content-Type": "application/json"
    }
    body = {
        "model": config.model_name,
        "messages": messages,
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
    }

    async with httpx.AsyncClient() as client:
        started = time.time()
        response = await retry_http_request(
            client, "POST", endpoint, headers=request_headers, json=body
        )
        elapsed = time.time() - started

        if response.status_code != 200:
            raise Exception(f"LLM API error: {response.status_code} - {response.text}")

        first_choice = response.json()["choices"][0]
        return first_choice["message"]["content"], elapsed
233
+
234
+
235
def format_observation(obs: Dict[str, Any]) -> str:
    """Render an observation dict as a short text prompt for the agent.

    Reads ``health``, ``food`` and ``drink`` (each defaulting to 10), the
    ``inventory`` mapping and the ``semantic_map`` entry. The map section is
    a placeholder: when ``semantic_map`` is present a fixed 5x5 grid of dots
    with ``P`` at the centre is shown, otherwise "Map unavailable".
    """
    inventory = obs.get("inventory", {})
    health, hunger, thirst = (obs.get(key, 10) for key in ("health", "food", "drink"))

    if obs.get("semantic_map") is None:
        map_str = "Map unavailable"
    else:
        # Simplified view: the map contents are not inspected, only the
        # player's position at the centre of the grid is marked.
        map_str = "\n".join(
            " ".join("P" if (dx, dy) == (0, 0) else "." for dx in range(-2, 3))
            for dy in range(-2, 3)
        )

    # List only items held in positive quantity, leaving out the vital stats.
    vital_keys = ["health", "food", "drink", "energy"]
    carried = [
        f"{name}:{quantity}"
        for name, quantity in inventory.items()
        if quantity > 0 and name not in vital_keys
    ]
    inv_str = ", ".join(carried) or "empty"

    return f"""Status: Health={health}/10, Hunger={hunger}/10, Thirst={thirst}/10
Inventory: {inv_str}
Nearby (5x5, P=player):
{map_str}

What action should you take? Choose one:
move_left, move_right, move_up, move_down, do, sleep, place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword

Action:"""
274
+
275
+
276
async def run_episode(
    episode_id: str,
    config: RolloutConfig,
    progress: Optional[Any] = None
) -> EpisodeStats:
    """Run a single Crafter episode driven by the configured LLM.

    Creates and resets an environment instance, then loops for at most
    ``config.max_steps_per_episode`` steps: format the observation, ask the
    LLM for an action, send it to the environment, and record the outcome.
    The instance is terminated on a best-effort basis at the end.

    Args:
        episode_id: Identifier used as the instance id and in request ids.
        config: Rollout settings (service URLs, difficulty, seed, ...).
        progress: Optional zero-argument callable invoked after each step
            that does not end the episode.

    Returns:
        The populated ``EpisodeStats``. Errors inside the step loop are
        recorded in ``termination_reason`` rather than raised.
    """
    stats = EpisodeStats(episode_id)

    async with httpx.AsyncClient() as client:
        # NOTE(review): create/reset/terminate use "/CrafterClassic/..." while
        # step uses "/env/CrafterClassic/step" — confirm that the target
        # environment service serves both prefixes.
        create_resp = await retry_http_request(
            client,
            "POST",
            f"{config.crafter_url}/CrafterClassic/create",
            json={
                "instance_id": episode_id,
                "render_mode": "rgb_array",
                "difficulty": config.difficulty,
                "seed": config.seed
            }
        )

        env_data = create_resp.json()
        instance_id = env_data["instance_id"]

        # Reset environment
        reset_resp = await retry_http_request(
            client,
            "POST",
            f"{config.crafter_url}/CrafterClassic/{instance_id}/reset",
            json={}
        )

        obs_data = reset_resp.json().get("private", {})

        # System message for the agent; kept separately so that it survives
        # the history trimming below.
        system_message = {
            "role": "system",
            "content": "You are playing Crafter, a survival game. Your goals are to: 1) Stay alive by maintaining health/hunger/thirst, 2) Gather resources (wood, stone, etc), 3) Craft tools and items. Respond with only the action name."
        }
        messages = [system_message]

        # Action name -> index sent to the environment.
        action_map = {
            'noop': 0, 'move_left': 1, 'move_right': 2, 'move_up': 3,
            'move_down': 4, 'do': 5, 'sleep': 6, 'place_stone': 7,
            'place_table': 8, 'place_furnace': 9, 'place_plant': 10,
            'make_wood_pickaxe': 11, 'make_stone_pickaxe': 12,
            'make_iron_pickaxe': 13, 'make_wood_sword': 14,
            'make_stone_sword': 15, 'make_iron_sword': 16
        }

        for step in range(config.max_steps_per_episode):
            prompt = format_observation(obs_data)
            messages.append({"role": "user", "content": prompt})

            try:
                response_text, response_time = await call_llm(messages, config)
                stats.llm_response_times.append(response_time)

                # Pick the first known action name (in map order) contained
                # in the reply; fall back to "do" when none is found.
                action = None
                response_lower = response_text.strip().lower()
                for action_name in action_map:
                    if action_name in response_lower:
                        action = action_name
                        break

                if not action:
                    action = "do"  # Default

                stats.actions_taken[action] += 1
                action_idx = action_map[action]

                step_payload = {
                    "env_id": instance_id,
                    "request_id": f"{episode_id}_step_{step}",
                    "action": {
                        "tool_calls": [{
                            "tool": "interact",
                            "args": {"action": action_idx}
                        }]
                    }
                }

                step_resp = await retry_http_request(
                    client,
                    "POST",
                    f"{config.crafter_url}/env/CrafterClassic/step",
                    json=step_payload
                )

                step_data = step_resp.json()
                new_obs = step_data.get("private", {})
                reward = step_data.get("reward", 0) or 0
                done = step_data.get("done", False)

                stats.total_reward += reward
                stats.steps += 1

                # Track achievements
                for ach, status in new_obs.get("achievements_status", {}).items():
                    if status and ach not in stats.achievements:
                        stats.achievements.append(ach)

                # Track the peak count seen for each non-vital inventory item.
                inv = new_obs.get("inventory", {})
                for item, count in inv.items():
                    if item not in ["health", "food", "drink", "energy"] and count > 0:
                        stats.resources_collected[item] = max(stats.resources_collected[item], count)

                # Update final stats
                stats.final_health = inv.get("health", 0)
                stats.final_hunger = inv.get("food", 0)
                stats.final_thirst = inv.get("drink", 0)

                # Keep the conversation short: retain the system prompt plus
                # the four most recent messages. A plain messages[-4:] would
                # drop the system prompt after the second step.
                messages = [system_message] + messages[1:][-4:]
                messages.append({"role": "assistant", "content": action})

                if done:
                    stats.termination_reason = step_data.get("termination_reason", "completed")
                    break

                obs_data = new_obs

                if progress:
                    progress()

            except Exception as e:
                stats.termination_reason = f"error: {str(e)}"
                break

        # Best-effort cleanup. Catch Exception rather than using a bare
        # except so that task cancellation and Ctrl-C still propagate.
        try:
            await client.post(f"{config.crafter_url}/CrafterClassic/{instance_id}/terminate")
        except Exception:
            pass

    stats.end_time = time.time()
    return stats
420
+
421
+
422
def create_results_table(all_stats: List[EpisodeStats]) -> Table:
    """Build the per-episode results table.

    Each episode becomes one row. The episode column shows the first eight
    characters of the id's last underscore-separated part. At most three
    resources are listed, followed by "..." when there are more.
    """
    table = Table(title="Crafter Rollout Results", show_header=True, header_style="bold magenta")

    column_specs = [
        ("Episode", {"style": "cyan", "width": 12}),
        ("Steps", {"justify": "right", "style": "green"}),
        ("Reward", {"justify": "right", "style": "yellow"}),
        ("Achievements", {"justify": "right", "style": "blue"}),
        ("Resources", {"justify": "center", "style": "magenta"}),
        ("Final Status", {"justify": "center"}),
        ("Time (s)", {"justify": "right", "style": "dim"}),
        ("Avg LLM (s)", {"justify": "right", "style": "dim"}),
    ]
    for header, options in column_specs:
        table.add_column(header, **options)

    for stats in all_stats:
        resources = [f"{item}:{count}" for item, count in stats.resources_collected.items()]
        if resources:
            resources_str = ", ".join(resources[:3])
        else:
            resources_str = "none"
        if len(resources) > 3:
            resources_str += "..."

        status = f"H:{stats.final_health} F:{stats.final_hunger} T:{stats.final_thirst}"

        # Colour the reward and achievement cells by outcome.
        if stats.total_reward > 0:
            reward_style = "green"
        else:
            reward_style = "red"
        achievement_count = len(stats.achievements)
        ach_style = "green" if achievement_count > 0 else "dim"

        short_id = stats.episode_id.split("_")[-1][:8]
        table.add_row(
            short_id,
            str(stats.steps),
            f"[{reward_style}]{stats.total_reward:.1f}[/{reward_style}]",
            f"[{ach_style}]{achievement_count}[/{ach_style}]",
            resources_str,
            status,
            f"{stats.duration():.1f}",
            f"{stats.avg_response_time():.1f}"
        )

    return table
463
+
464
+
465
def create_summary_panel(all_stats: List[EpisodeStats], config: RolloutConfig) -> Panel:
    """Build the aggregate summary panel for a batch of episodes.

    Reports the model name, the episode count, how many episodes had a
    positive total reward, the mean reward/steps/achievements, and the five
    most frequent achievements and actions.
    """
    total_episodes = len(all_stats)
    successful_episodes = len([s for s in all_stats if s.total_reward > 0])

    if all_stats:
        avg_reward = np.mean([s.total_reward for s in all_stats])
        avg_steps = np.mean([s.steps for s in all_stats])
        avg_achievements = np.mean([len(s.achievements) for s in all_stats])
    else:
        avg_reward = avg_steps = avg_achievements = 0

    # Tally, across all episodes, how many episodes unlocked each
    # achievement and how often each action was taken.
    all_achievements = defaultdict(int)
    all_actions = defaultdict(int)
    for stats in all_stats:
        for ach in stats.achievements:
            all_achievements[ach] += 1
        for action, count in stats.actions_taken.items():
            all_actions[action] += count

    # Stable sort on the negated count: highest first, ties in first-seen order.
    top_achievements = sorted(all_achievements.items(), key=lambda kv: -kv[1])[:5]
    top_actions = sorted(all_actions.items(), key=lambda kv: -kv[1])[:5]

    parts = [
        f"[bold]Model:[/bold] {config.model_name}\n",
        f"[bold]Episodes:[/bold] {total_episodes} (Successful: {successful_episodes})\n",
        f"[bold]Average Reward:[/bold] {avg_reward:.2f}\n",
        f"[bold]Average Steps:[/bold] {avg_steps:.1f}\n",
        f"[bold]Average Achievements:[/bold] {avg_achievements:.1f}\n",
        "\n[bold]Top Achievements:[/bold]\n",
    ]

    for ach, count in top_achievements:
        pct = (count / total_episodes) * 100
        parts.append(f" • {ach}: {count} ({pct:.0f}%)\n")

    parts.append("\n[bold]Top Actions:[/bold]\n")
    parts.extend(f" • {action}: {count}\n" for action, count in top_actions)

    summary_text = "".join(parts)
    return Panel(summary_text, title="Summary Statistics", border_style="green")
506
+
507
+
508
async def main():
    """Parse CLI arguments, run the rollouts, and report the results.

    Steps: build a ``RolloutConfig`` from the arguments, optionally warm up
    the model, run all episodes with at most three in flight at once, print
    the results table and summary, and write a JSON results file when
    ``--save`` is given.
    """
    parser = argparse.ArgumentParser(description="Run Crafter rollouts with Qwen models")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B-Instruct",
                        help="Model name (e.g., Qwen/Qwen2.5-7B-Instruct)")
    parser.add_argument("--episodes", type=int, default=10,
                        help="Number of episodes to run")
    parser.add_argument("--max-steps", type=int, default=100,
                        help="Maximum steps per episode")
    parser.add_argument("--difficulty", type=str, default="easy",
                        choices=["easy", "normal", "hard", "peaceful"],
                        help="Game difficulty")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducibility")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="LLM temperature")
    parser.add_argument("--save", action="store_true",
                        help="Save results to file")
    parser.add_argument("--output", type=str, default=None,
                        help="Output file for results")
    parser.add_argument("--skip-warmup", action="store_true",
                        help="Skip model warmup phase")

    args = parser.parse_args()

    # Create config
    config = RolloutConfig(
        model_name=args.model,
        num_episodes=args.episodes,
        max_steps_per_episode=args.max_steps,
        difficulty=args.difficulty,
        seed=args.seed,
        temperature=args.temperature,
        save_results=args.save,
        output_file=args.output
    )

    # Set up logging - suppress httpx INFO logs
    logging.basicConfig(level=logging.WARNING)
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)

    console.print(f"[bold green]🎮 Crafter Rollouts with {config.model_name}[/bold green]")
    console.print(f"Episodes: {config.num_episodes}, Max steps: {config.max_steps_per_episode}")
    console.print(f"Difficulty: {config.difficulty}, Temperature: {config.temperature}")

    # Show expected routing
    expected_category = get_model_size_category(config.model_name)
    console.print(f"[dim]Expected Modal container: base_model_{expected_category}_generate[/dim]")
    console.print()

    # Warm up the model first; on failure, ask whether to continue.
    if not args.skip_warmup:
        warmup_success = await warmup_model(config)
        if not warmup_success:
            console.print("[red]Failed to warmup model. Continue anyway? (y/n)[/red]")
            response = input().strip().lower()
            if response != 'y':
                return
    else:
        console.print("[yellow]Skipping model warmup (--skip-warmup specified)[/yellow]")

    console.print()
    all_stats = []

    # Run episodes with progress bar
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeRemainingColumn(),
        console=console
    ) as progress:

        # The bar counts steps, so its total is the maximum possible number
        # of steps across all episodes.
        total_steps = config.num_episodes * config.max_steps_per_episode
        task = progress.add_task(f"Running {config.num_episodes} episodes...", total=total_steps)

        tasks = []
        for i in range(config.num_episodes):
            episode_id = f"qwen_{i}_{uuid.uuid4().hex[:8]}"
            task_coro = run_episode(episode_id, config, lambda: progress.update(task, advance=1))
            tasks.append(task_coro)

        # Limit concurrency to avoid overwhelming the services
        sem = asyncio.Semaphore(3)

        async def run_with_semaphore(coro):
            async with sem:
                return await coro

        results = await asyncio.gather(*[run_with_semaphore(t) for t in tasks], return_exceptions=True)

        for i, result in enumerate(results):
            if isinstance(result, Exception):
                console.print(f"[red]Episode {i} failed: {result}[/red]")
            else:
                all_stats.append(result)

    # Display results
    console.print()

    if all_stats:
        # Show results table
        table = create_results_table(all_stats)
        console.print(table)
        console.print()

        # Show summary
        summary = create_summary_panel(all_stats, config)
        console.print(summary)

        # Save results if requested
        if config.save_results:
            output_file = config.output_file or f"qwen_rollouts_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

            # Prefer pydantic v2's model_dump(); fall back to the v1 dict()
            # API, which is deprecated under v2.
            if hasattr(config, "model_dump"):
                config_data = config.model_dump()
            else:
                config_data = config.dict()

            results_data = {
                "config": config_data,
                "timestamp": datetime.now().isoformat(),
                "episodes": [
                    {
                        "episode_id": s.episode_id,
                        "steps": s.steps,
                        "total_reward": s.total_reward,
                        "achievements": s.achievements,
                        "resources_collected": dict(s.resources_collected),
                        "actions_taken": dict(s.actions_taken),
                        "final_health": s.final_health,
                        "final_hunger": s.final_hunger,
                        "final_thirst": s.final_thirst,
                        "duration": s.duration(),
                        "avg_response_time": s.avg_response_time(),
                        "termination_reason": s.termination_reason
                    }
                    for s in all_stats
                ]
            }

            # Explicit encoding keeps the output independent of the locale.
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(results_data, f, indent=2)

            console.print(f"\n[green]Results saved to: {output_file}[/green]")
    else:
        console.print("[red]No successful episodes completed![/red]")


if __name__ == "__main__":
    asyncio.run(main())