synth-ai 0.2.4.dev7__py3-none-any.whl → 0.2.4.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +1 -1
- synth_ai/cli/__init__.py +6 -0
- synth_ai/cli/balance.py +3 -15
- synth_ai/cli/demo.py +68 -9
- synth_ai/cli/rl_demo.py +137 -0
- synth_ai/cli/root.py +65 -0
- synth_ai/config/base_url.py +47 -0
- synth_ai/demos/core/__init__.py +1 -0
- synth_ai/demos/core/cli.py +621 -0
- synth_ai/demos/demo_task_apps/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/core.py +374 -0
- synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/math/app.py +37 -0
- synth_ai/demos/demo_task_apps/math/config.toml +44 -0
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
- synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
- synth_ai/environments/examples/bandit/__init__.py +33 -0
- synth_ai/environments/examples/bandit/engine.py +294 -0
- synth_ai/environments/examples/bandit/environment.py +194 -0
- synth_ai/environments/examples/bandit/taskset.py +200 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
- synth_ai/environments/examples/crafter_classic/environment.py +41 -2
- synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
- synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
- synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
- synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
- synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/http.py +102 -0
- synth_ai/inference/__init__.py +7 -0
- synth_ai/inference/client.py +20 -0
- synth_ai/install_sqld.sh +40 -0
- synth_ai/jobs/client.py +246 -0
- synth_ai/learning/__init__.py +24 -0
- synth_ai/learning/client.py +149 -0
- synth_ai/learning/config.py +43 -0
- synth_ai/learning/constants.py +29 -0
- synth_ai/learning/ft_client.py +59 -0
- synth_ai/learning/health.py +43 -0
- synth_ai/learning/jobs.py +205 -0
- synth_ai/learning/rl_client.py +256 -0
- synth_ai/learning/sse.py +58 -0
- synth_ai/learning/validators.py +48 -0
- synth_ai/lm/core/main_v3.py +13 -0
- synth_ai/lm/core/synth_models.py +48 -0
- synth_ai/lm/core/vendor_clients.py +9 -6
- synth_ai/lm/vendors/core/openai_api.py +31 -3
- synth_ai/lm/vendors/openai_standard.py +45 -14
- synth_ai/lm/vendors/supported/custom_endpoint.py +12 -2
- synth_ai/lm/vendors/synth_client.py +372 -28
- synth_ai/rl/__init__.py +30 -0
- synth_ai/rl/contracts.py +32 -0
- synth_ai/rl/env_keys.py +137 -0
- synth_ai/rl/secrets.py +19 -0
- synth_ai/scripts/verify_rewards.py +100 -0
- synth_ai/task/__init__.py +10 -0
- synth_ai/task/contracts.py +120 -0
- synth_ai/task/health.py +28 -0
- synth_ai/task/validators.py +12 -0
- synth_ai/tracing_v3/hooks.py +3 -1
- synth_ai/tracing_v3/session_tracer.py +123 -2
- synth_ai/tracing_v3/turso/manager.py +218 -0
- synth_ai/tracing_v3/turso/models.py +53 -0
- synth_ai-0.2.4.dev9.dist-info/METADATA +91 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/RECORD +147 -30
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/entry_points.txt +1 -0
- synth_ai/tui/__init__.py +0 -1
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -340
- synth_ai-0.2.4.dev7.dist-info/METADATA +0 -193
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/top_level.txt +0 -0
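The single hunk shown below is the new synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py, added in its entirety (655 lines); a short usage sketch follows the hunk.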
@@ -0,0 +1,655 @@
#!/usr/bin/env python3
"""
Run Crafter rollouts with Qwen models and display results in a table format
"""

import asyncio
import json
import uuid
import argparse
import logging
import time
from datetime import datetime
from typing import Dict, Any, Optional, List, Tuple
from pydantic import BaseModel
import httpx
import os
from pathlib import Path
import numpy as np
from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeRemainingColumn
from rich.live import Live
from rich.layout import Layout
from rich.panel import Panel
from collections import defaultdict

# Disable Langfuse
os.environ["LANGFUSE_ENABLED"] = "false"
os.environ["LANGFUSE_PUBLIC_KEY"] = "dummy"
os.environ["LANGFUSE_SECRET_KEY"] = "dummy"

# Import Crafter hooks
try:
    from synth_ai.environments.examples.crafter_classic.trace_hooks import CRAFTER_HOOKS
except ImportError:
    CRAFTER_HOOKS = []

# Service configuration
MODAL_BASE_URL = "https://synth-laboratories--unified-ft-service-fastapi-app.modal.run"
MODAL_API_KEY = os.environ.get("MODAL_API_KEY", "sk-test-11111111111111111111111111111111")

# Model size routing based on Modal service configuration
MODEL_SIZE_ROUTING = {
    "0.5B": "small",
    "1.5B": "small",
    "3B": "small",
    "7B": "medium",
    "14B": "medium",
    "32B": "large32",
    "72B": "large72"
}

def get_model_size_category(model_name: str) -> str:
    """Get the size category for routing."""
    for size, category in MODEL_SIZE_ROUTING.items():
        if f"-{size}-" in model_name or model_name.endswith(f"-{size}"):
            return category
    return "medium"  # Default to medium

# HTTP retry configuration
MAX_RETRIES = 3
BASE_DELAY = 0.1
MAX_DELAY = 2.0
HTTP_TIMEOUT = 120.0

console = Console()

class RolloutConfig(BaseModel):
    """Configuration for rollout evaluation."""
    # Model settings
    model_name: str = "Qwen/Qwen2.5-7B-Instruct"
    temperature: float = 0.7
    max_tokens: int = 512

    # Evaluation settings
    num_episodes: int = 10
    max_steps_per_episode: int = 100
    difficulty: str = "easy"
    seed: Optional[int] = None

    # Service settings
    crafter_url: str = "http://localhost:8901"
    llm_base_url: str = MODAL_BASE_URL
    llm_api_key: str = MODAL_API_KEY

    # Display settings
    show_live_progress: bool = True
    save_results: bool = True
    output_file: Optional[str] = None


class EpisodeStats:
    """Track statistics for an episode."""
    def __init__(self, episode_id: str):
        self.episode_id = episode_id
        self.steps = 0
        self.total_reward = 0.0
        self.achievements = []
        self.final_health = 0
        self.final_hunger = 0
        self.final_thirst = 0
        self.resources_collected = defaultdict(int)
        self.actions_taken = defaultdict(int)
        self.start_time = time.time()
        self.end_time = None
        self.termination_reason = None
        self.llm_response_times = []

    def duration(self) -> float:
        if self.end_time:
            return self.end_time - self.start_time
        return time.time() - self.start_time

    def avg_response_time(self) -> float:
        if self.llm_response_times:
            return np.mean(self.llm_response_times)
        return 0.0


async def retry_http_request(client: httpx.AsyncClient, method: str, url: str, **kwargs) -> Any:
    """Retry HTTP requests with exponential backoff."""
    for attempt in range(MAX_RETRIES):
        try:
            if attempt > 0:
                await asyncio.sleep(BASE_DELAY * (2 ** (attempt - 1)))

            response = await client.request(method, url, timeout=HTTP_TIMEOUT, **kwargs)

            if response.status_code < 500:
                return response

        except Exception as e:
            if attempt == MAX_RETRIES - 1:
                raise e

    raise Exception(f"Failed after {MAX_RETRIES} attempts")


async def warmup_model(config: RolloutConfig, max_attempts: int = 30) -> bool:
    """Warmup the model by polling until it's ready."""
    console.print(f"[yellow]Warming up {config.model_name}...[/yellow]")

    # First try the warmup endpoint if available
    async with httpx.AsyncClient() as client:
        headers = {
            "Authorization": f"Bearer {config.llm_api_key}",
            "Content-Type": "application/json"
        }

        # Try warmup endpoint
        try:
            warmup_url = f"{config.llm_base_url}/warmup/{config.model_name}"
            response = await client.post(warmup_url, headers=headers, timeout=30.0)
            if response.status_code == 200:
                console.print("[green]✓ Model warmup endpoint called[/green]")
        except:
            pass  # Warmup endpoint might not exist

        # Now poll with actual inference requests
        test_messages = [
            {"role": "user", "content": "Say 'ready' if you're loaded."}
        ]

        for attempt in range(max_attempts):
            try:
                start_time = time.time()
                response = await client.post(
                    f"{config.llm_base_url}/v1/chat/completions",
                    headers=headers,
                    json={
                        "model": config.model_name,
                        "messages": test_messages,
                        "temperature": 0.1,
                        "max_tokens": 10,
                    },
                    timeout=120.0
                )
                elapsed = time.time() - start_time

                if response.status_code == 200:
                    data = response.json()
                    if "choices" in data and data["choices"]:
                        console.print(f"[green]✓ Model ready! (response time: {elapsed:.1f}s)[/green]")
                        return True

                # If we get here, model is still loading
                if elapsed > 10:
                    console.print(f"[yellow]Model is loading... attempt {attempt + 1}/{max_attempts} (took {elapsed:.1f}s)[/yellow]")

            except httpx.TimeoutException:
                console.print(f"[yellow]Timeout waiting for model... attempt {attempt + 1}/{max_attempts}[/yellow]")
            except Exception as e:
                console.print(f"[yellow]Error during warmup: {str(e)[:100]}[/yellow]")

            # Wait before retrying
            await asyncio.sleep(5)

    console.print(f"[red]Failed to warmup model after {max_attempts} attempts[/red]")
    return False


async def call_llm(messages: List[Dict[str, str]], config: RolloutConfig) -> Tuple[str, float]:
    """Call LLM and return response with timing."""
    async with httpx.AsyncClient() as client:
        headers = {
            "Authorization": f"Bearer {config.llm_api_key}",
            "Content-Type": "application/json"
        }

        payload = {
            "model": config.model_name,
            "messages": messages,
            "temperature": config.temperature,
            "max_tokens": config.max_tokens,
        }

        start_time = time.time()
        response = await retry_http_request(
            client,
            "POST",
            f"{config.llm_base_url}/v1/chat/completions",
            headers=headers,
            json=payload
        )
        elapsed = time.time() - start_time

        if response.status_code != 200:
            raise Exception(f"LLM API error: {response.status_code} - {response.text}")

        data = response.json()
        return data["choices"][0]["message"]["content"], elapsed


def format_observation(obs: Dict[str, Any]) -> str:
    """Format observation into a concise prompt."""
    inv = obs.get("inventory", {})
    health = obs.get("health", 10)
    hunger = obs.get("food", 10)
    thirst = obs.get("drink", 10)

    # Get nearby objects in a 5x5 view
    semantic_map = obs.get("semantic_map")
    if semantic_map is not None:
        # Simple 5x5 view around player
        view = []
        for dy in range(-2, 3):
            row = []
            for dx in range(-2, 3):
                if dx == 0 and dy == 0:
                    row.append("P")
                else:
                    # Simplified - just show if something is there
                    row.append(".")
            view.append(" ".join(row))
        map_str = "\n".join(view)
    else:
        map_str = "Map unavailable"

    # Format inventory (only non-zero items)
    inv_items = [f"{k}:{v}" for k, v in inv.items()
                 if v > 0 and k not in ["health", "food", "drink", "energy"]]
    inv_str = ", ".join(inv_items) if inv_items else "empty"

    return f"""Status: Health={health}/10, Hunger={hunger}/10, Thirst={thirst}/10
Inventory: {inv_str}
Nearby (5x5, P=player):
{map_str}

What action should you take? Choose one:
move_left, move_right, move_up, move_down, do, sleep, place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe, make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword

Action:"""


async def run_episode(
    episode_id: str,
    config: RolloutConfig,
    progress: Optional[Any] = None
) -> EpisodeStats:
    """Run a single episode."""
    stats = EpisodeStats(episode_id)

    async with httpx.AsyncClient() as client:
        # Create environment
        create_resp = await retry_http_request(
            client,
            "POST",
            f"{config.crafter_url}/CrafterClassic/create",
            json={
                "instance_id": episode_id,
                "render_mode": "rgb_array",
                "difficulty": config.difficulty,
                "seed": config.seed
            }
        )

        env_data = create_resp.json()
        instance_id = env_data["instance_id"]

        # Reset environment
        reset_resp = await retry_http_request(
            client,
            "POST",
            f"{config.crafter_url}/CrafterClassic/{instance_id}/reset",
            json={}
        )

        obs_data = reset_resp.json().get("private", {})

        # System message for the agent
        messages = [{
            "role": "system",
            "content": "You are playing Crafter, a survival game. Your goals are to: 1) Stay alive by maintaining health/hunger/thirst, 2) Gather resources (wood, stone, etc), 3) Craft tools and items. Respond with only the action name."
        }]

        # Action mapping
        action_map = {
            'noop': 0, 'move_left': 1, 'move_right': 2, 'move_up': 3,
            'move_down': 4, 'do': 5, 'sleep': 6, 'place_stone': 7,
            'place_table': 8, 'place_furnace': 9, 'place_plant': 10,
            'make_wood_pickaxe': 11, 'make_stone_pickaxe': 12,
            'make_iron_pickaxe': 13, 'make_wood_sword': 14,
            'make_stone_sword': 15, 'make_iron_sword': 16
        }

        # Run episode
        for step in range(config.max_steps_per_episode):
            # Create prompt
            prompt = format_observation(obs_data)
            messages.append({"role": "user", "content": prompt})

            # Get LLM response
            try:
                response_text, response_time = await call_llm(messages, config)
                stats.llm_response_times.append(response_time)

                # Parse action
                action = None
                response_lower = response_text.strip().lower()
                for action_name in action_map.keys():
                    if action_name in response_lower:
                        action = action_name
                        break

                if not action:
                    action = "do"  # Default

                stats.actions_taken[action] += 1
                action_idx = action_map[action]

                # Take action
                step_payload = {
                    "env_id": instance_id,
                    "request_id": f"{episode_id}_step_{step}",
                    "action": {
                        "tool_calls": [{
                            "tool": "interact",
                            "args": {"action": action_idx}
                        }]
                    }
                }

                step_resp = await retry_http_request(
                    client,
                    "POST",
                    f"{config.crafter_url}/env/CrafterClassic/step",
                    json=step_payload
                )

                step_data = step_resp.json()
                new_obs = step_data.get("private", {})
                reward = step_data.get("reward", 0) or 0
                done = step_data.get("done", False)

                stats.total_reward += reward
                stats.steps += 1

                # Track achievements
                for ach, status in new_obs.get("achievements_status", {}).items():
                    if status and ach not in stats.achievements:
                        stats.achievements.append(ach)

                # Track resources
                inv = new_obs.get("inventory", {})
                for item, count in inv.items():
                    if item not in ["health", "food", "drink", "energy"] and count > 0:
                        stats.resources_collected[item] = max(stats.resources_collected[item], count)

                # Update final stats
                stats.final_health = inv.get("health", 0)
                stats.final_hunger = inv.get("food", 0)
                stats.final_thirst = inv.get("drink", 0)

                # Keep conversation short
                messages = messages[-4:]  # Keep only recent context
                messages.append({"role": "assistant", "content": action})

                if done:
                    stats.termination_reason = step_data.get("termination_reason", "completed")
                    break

                obs_data = new_obs

                if progress:
                    progress()

            except Exception as e:
                stats.termination_reason = f"error: {str(e)}"
                break

        # Clean up
        try:
            await client.post(f"{config.crafter_url}/CrafterClassic/{instance_id}/terminate")
        except:
            pass

    stats.end_time = time.time()
    return stats


def create_results_table(all_stats: List[EpisodeStats]) -> Table:
    """Create a rich table with results."""
    table = Table(title="Crafter Rollout Results", show_header=True, header_style="bold magenta")

    table.add_column("Episode", style="cyan", width=12)
    table.add_column("Steps", justify="right", style="green")
    table.add_column("Reward", justify="right", style="yellow")
    table.add_column("Achievements", justify="right", style="blue")
    table.add_column("Resources", justify="center", style="magenta")
    table.add_column("Final Status", justify="center")
    table.add_column("Time (s)", justify="right", style="dim")
    table.add_column("Avg LLM (s)", justify="right", style="dim")

    for stats in all_stats:
        # Format resources
        resources = []
        for item, count in stats.resources_collected.items():
            resources.append(f"{item}:{count}")
        resources_str = ", ".join(resources[:3]) if resources else "none"
        if len(resources) > 3:
            resources_str += "..."

        # Format final status
        status = f"H:{stats.final_health} F:{stats.final_hunger} T:{stats.final_thirst}"

        # Color code based on performance
        reward_style = "green" if stats.total_reward > 0 else "red"
        ach_style = "green" if len(stats.achievements) > 0 else "dim"

        table.add_row(
            stats.episode_id.split("_")[-1][:8],
            str(stats.steps),
            f"[{reward_style}]{stats.total_reward:.1f}[/{reward_style}]",
            f"[{ach_style}]{len(stats.achievements)}[/{ach_style}]",
            resources_str,
            status,
            f"{stats.duration():.1f}",
            f"{stats.avg_response_time():.1f}"
        )

    return table


def create_summary_panel(all_stats: List[EpisodeStats], config: RolloutConfig) -> Panel:
    """Create a summary panel."""
    total_episodes = len(all_stats)
    successful_episodes = sum(1 for s in all_stats if s.total_reward > 0)

    avg_reward = np.mean([s.total_reward for s in all_stats]) if all_stats else 0
    avg_steps = np.mean([s.steps for s in all_stats]) if all_stats else 0
    avg_achievements = np.mean([len(s.achievements) for s in all_stats]) if all_stats else 0

    # Count all achievements
    all_achievements = defaultdict(int)
    for stats in all_stats:
        for ach in stats.achievements:
            all_achievements[ach] += 1

    # Most common actions
    all_actions = defaultdict(int)
    for stats in all_stats:
        for action, count in stats.actions_taken.items():
            all_actions[action] += count

    top_actions = sorted(all_actions.items(), key=lambda x: x[1], reverse=True)[:5]

    summary_text = f"""[bold]Model:[/bold] {config.model_name}
[bold]Episodes:[/bold] {total_episodes} (Successful: {successful_episodes})
[bold]Average Reward:[/bold] {avg_reward:.2f}
[bold]Average Steps:[/bold] {avg_steps:.1f}
[bold]Average Achievements:[/bold] {avg_achievements:.1f}

[bold]Top Achievements:[/bold]
"""

    for ach, count in sorted(all_achievements.items(), key=lambda x: x[1], reverse=True)[:5]:
        pct = (count / total_episodes) * 100
        summary_text += f"  • {ach}: {count} ({pct:.0f}%)\n"

    summary_text += "\n[bold]Top Actions:[/bold]\n"
    for action, count in top_actions:
        summary_text += f"  • {action}: {count}\n"

    return Panel(summary_text, title="Summary Statistics", border_style="green")


async def main():
    """Main function."""
    parser = argparse.ArgumentParser(description="Run Crafter rollouts with Qwen models")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B-Instruct",
                        help="Model name (e.g., Qwen/Qwen2.5-7B-Instruct)")
    parser.add_argument("--episodes", type=int, default=10,
                        help="Number of episodes to run")
    parser.add_argument("--max-steps", type=int, default=100,
                        help="Maximum steps per episode")
    parser.add_argument("--difficulty", type=str, default="easy",
                        choices=["easy", "normal", "hard", "peaceful"],
                        help="Game difficulty")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducibility")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="LLM temperature")
    parser.add_argument("--save", action="store_true",
                        help="Save results to file")
    parser.add_argument("--output", type=str, default=None,
                        help="Output file for results")
    parser.add_argument("--skip-warmup", action="store_true",
                        help="Skip model warmup phase")

    args = parser.parse_args()

    # Create config
    config = RolloutConfig(
        model_name=args.model,
        num_episodes=args.episodes,
        max_steps_per_episode=args.max_steps,
        difficulty=args.difficulty,
        seed=args.seed,
        temperature=args.temperature,
        save_results=args.save,
        output_file=args.output
    )

    # Set up logging - suppress httpx INFO logs
    logging.basicConfig(level=logging.WARNING)
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)

    console.print(f"[bold green]🎮 Crafter Rollouts with {config.model_name}[/bold green]")
    console.print(f"Episodes: {config.num_episodes}, Max steps: {config.max_steps_per_episode}")
    console.print(f"Difficulty: {config.difficulty}, Temperature: {config.temperature}")

    # Show expected routing
    expected_category = get_model_size_category(config.model_name)
    console.print(f"[dim]Expected Modal container: base_model_{expected_category}_generate[/dim]")
    console.print()

    # Warmup the model first
    if not args.skip_warmup:
        warmup_success = await warmup_model(config)
        if not warmup_success:
            console.print("[red]Failed to warmup model. Continue anyway? (y/n)[/red]")
            response = input().strip().lower()
            if response != 'y':
                return
    else:
        console.print("[yellow]Skipping model warmup (--skip-warmup specified)[/yellow]")

    console.print()
    all_stats = []

    # Run episodes with progress bar
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeRemainingColumn(),
        console=console
    ) as progress:

        total_steps = config.num_episodes * config.max_steps_per_episode
        task = progress.add_task(f"Running {config.num_episodes} episodes...", total=total_steps)

        # Run episodes concurrently
        tasks = []
        for i in range(config.num_episodes):
            episode_id = f"qwen_{i}_{uuid.uuid4().hex[:8]}"
            task_coro = run_episode(episode_id, config, lambda: progress.update(task, advance=1))
            tasks.append(task_coro)

        # Limit concurrency to avoid overwhelming the services
        sem = asyncio.Semaphore(3)
        async def run_with_semaphore(coro):
            async with sem:
                return await coro

        results = await asyncio.gather(*[run_with_semaphore(t) for t in tasks], return_exceptions=True)

        for i, result in enumerate(results):
            if isinstance(result, Exception):
                console.print(f"[red]Episode {i} failed: {result}[/red]")
            else:
                all_stats.append(result)

    # Display results
    console.print()

    if all_stats:
        # Show results table
        table = create_results_table(all_stats)
        console.print(table)
        console.print()

        # Show summary
        summary = create_summary_panel(all_stats, config)
        console.print(summary)

        # Save results if requested
        if config.save_results:
            output_file = config.output_file or f"qwen_rollouts_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

            results_data = {
                "config": config.dict(),
                "timestamp": datetime.now().isoformat(),
                "episodes": [
                    {
                        "episode_id": s.episode_id,
                        "steps": s.steps,
                        "total_reward": s.total_reward,
                        "achievements": s.achievements,
                        "resources_collected": dict(s.resources_collected),
                        "actions_taken": dict(s.actions_taken),
                        "final_health": s.final_health,
                        "final_hunger": s.final_hunger,
                        "final_thirst": s.final_thirst,
                        "duration": s.duration(),
                        "avg_response_time": s.avg_response_time(),
                        "termination_reason": s.termination_reason
                    }
                    for s in all_stats
                ]
            }

            with open(output_file, "w") as f:
                json.dump(results_data, f, indent=2)

            console.print(f"\n[green]Results saved to: {output_file}[/green]")
    else:
        console.print("[red]No successful episodes completed![/red]")


if __name__ == "__main__":
    asyncio.run(main())
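For orientation, below is a minimal sketch of driving the new rollout helpers programmatically rather than through the script's argparse CLI. It assumes the module is importable from its packaged path, that a Crafter task service is listening on the default http://localhost:8901, and that the Modal inference endpoint configured above is reachable; the model choice, episode id, and episode/step counts are illustrative only.

# Minimal programmatic sketch (assumptions noted above).
import asyncio

from synth_ai.environments.examples.crafter_classic.agent_demos.old.run_qwen_rollouts import (
    RolloutConfig,
    run_episode,
    warmup_model,
)


async def demo() -> None:
    # One short episode against a 0.5B model, which MODEL_SIZE_ROUTING maps to the "small" container.
    config = RolloutConfig(
        model_name="Qwen/Qwen2.5-0.5B-Instruct",
        num_episodes=1,
        max_steps_per_episode=20,
    )
    # Poll the inference endpoint until the model responds, then run a single episode.
    if await warmup_model(config, max_attempts=5):
        stats = await run_episode("demo_0", config)
        print(stats.steps, stats.total_reward, stats.achievements)


if __name__ == "__main__":
    asyncio.run(demo())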