synth-ai 0.2.9.dev2__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/analyze_semantic_words.sh +17 -0
- examples/common_old/backend.py +21 -0
- examples/crafter_debug_render.py +180 -0
- examples/evals_old/README.md +98 -0
- examples/evals_old/__init__.py +6 -0
- examples/evals_old/compare_models.py +1037 -0
- examples/evals_old/example_log.md +145 -0
- examples/evals_old/run_demo.sh +126 -0
- examples/evals_old/trace_analysis.py +270 -0
- examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
- examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
- examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
- examples/finetuning_old/synth_qwen_v1/README.md +68 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
- examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
- examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
- examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
- examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
- examples/finetuning_old/synth_qwen_v1/util.py +147 -0
- examples/rl/README.md +169 -0
- examples/rl/configs/eval_base_qwen.toml +15 -0
- examples/rl/configs/eval_rl_qwen.toml +11 -0
- examples/rl/configs/rl_from_base_qwen.toml +35 -0
- examples/rl/configs/rl_from_base_qwen17.toml +74 -0
- examples/rl/configs/rl_from_ft_qwen.toml +35 -0
- examples/rl/download_dataset.py +64 -0
- examples/rl/run_eval.py +435 -0
- examples/rl/run_rl_and_save.py +94 -0
- examples/rl/task_app/README.md +22 -0
- {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
- examples/rl/task_app/math_task_app.py +107 -0
- examples/rl_old/task_app.py +962 -0
- examples/run_crafter_demo.sh +10 -0
- examples/warming_up_to_rl/analyze_trace_db.py +420 -0
- examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
- examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
- examples/warming_up_to_rl/export_trace_sft.py +541 -0
- examples/warming_up_to_rl/groq_test.py +88 -0
- examples/warming_up_to_rl/manage_secrets.py +127 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/readme.md +172 -0
- examples/warming_up_to_rl/run_eval.py +434 -0
- examples/warming_up_to_rl/run_fft_and_save.py +309 -0
- examples/warming_up_to_rl/run_local_rollout.py +188 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
- examples/warming_up_to_rl/run_rl_and_save.py +101 -0
- examples/warming_up_to_rl/run_rollout_remote.py +129 -0
- examples/warming_up_to_rl/task_app/README.md +38 -0
- {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
- synth_ai/api/train/config_finder.py +18 -18
- synth_ai/api/train/env_resolver.py +28 -1
- synth_ai/cli/task_apps.py +264 -55
- synth_ai/demo_registry.py +7 -7
- synth_ai/demos/demo_task_apps/crafter/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +54 -0
- synth_ai/demos/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +165 -0
- synth_ai/task/apps/__init__.py +54 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +112 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,434 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Baseline evaluation script (public-friendly skeleton)
|
|
4
|
+
- Targets a task app (Crafter-like) via initialize/step/terminate
|
|
5
|
+
- Uses a TaskAppClient interface (to be implemented in synth-ai SDK)
|
|
6
|
+
- Keeps structure aligned with research/testing/crafter eval harness
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
import os
|
|
10
|
+
import json
|
|
11
|
+
import re
|
|
12
|
+
from typing import Any, Dict, List, Optional
|
|
13
|
+
from collections import Counter
|
|
14
|
+
import asyncio
|
|
15
|
+
import httpx
|
|
16
|
+
import argparse
|
|
17
|
+
import tomllib
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
|
|
20
|
+
class TaskAppClient:
    """Minimal async client for the task app initialize/step/terminate routes.

    This is a public-friendly shim for examples, pending SDK surface consolidation.

    Use it as an async context manager so the underlying HTTP client is closed
    on exit. It can also be used directly, in which case the HTTP client is
    created lazily on first access to ``client`` and is never closed
    automatically.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None) -> None:
        """Store connection settings; no HTTP client is created yet.

        Args:
            base_url: Root URL of the task app; trailing slashes are stripped.
            api_key: Optional key sent as the ``X-API-Key`` header.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self._client: Optional["httpx.AsyncClient"] = None

    def _auth_headers(self) -> Dict[str, str]:
        """Return the ``X-API-Key`` header mapping, or ``{}`` when no key is set."""
        return {"X-API-Key": self.api_key} if self.api_key else {}

    def _build_client(self) -> "httpx.AsyncClient":
        """Create the HTTP client shared by all task-app requests."""
        return httpx.AsyncClient(
            base_url=self.base_url,
            headers=self._auth_headers(),
            timeout=600.0,
            follow_redirects=True,
        )

    async def __aenter__(self) -> "TaskAppClient":
        # Reuse a client that was already created lazily through ``client``;
        # replacing it here would leave the earlier one unclosed.
        if self._client is None:
            self._client = self._build_client()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if self._client is not None:
            await self._client.aclose()
            self._client = None

    @property
    def client(self) -> "httpx.AsyncClient":
        """The underlying HTTP client, created on first access if needed."""
        if self._client is None:
            # Fallback for direct use without context manager
            self._client = self._build_client()
        return self._client

    async def initialize(self, env_name: str, config: Dict[str, Any]) -> Dict[str, Any]:
        """POST /env/{env_name}/initialize (compat route supported in task app).

        Args:
            env_name: Environment name used in the route.
            config: May contain ``seed``, ``world_config`` and ``config``;
                only these keys are forwarded.

        Returns:
            The decoded JSON response body.

        Raises:
            httpx.HTTPStatusError: If the response status is an error.
        """
        payload: Dict[str, Any] = {"seed": config.get("seed")}
        # Allow both world_config and config inputs; env routes will normalize difficulty
        for key in ("world_config", "config"):
            if key in config:
                payload[key] = config[key]
        resp = await self.client.post(f"/env/{env_name}/initialize", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def step(self, env_name: str, env_id: str, tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
        """POST /env/{env_name}/step with wrapped tool_calls in action.

        Raises:
            httpx.HTTPStatusError: If the response status is an error.
        """
        payload = {"env_id": env_id, "action": {"tool_calls": tool_calls}}
        resp = await self.client.post(f"/env/{env_name}/step", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def terminate(self, env_name: str, env_id: str) -> Dict[str, Any]:
        """POST /env/{env_name}/terminate for the given ``env_id``.

        Raises:
            httpx.HTTPStatusError: If the response status is an error.
        """
        resp = await self.client.post(f"/env/{env_name}/terminate", json={"env_id": env_id})
        resp.raise_for_status()
        return resp.json()

    async def get_info(self) -> Dict[str, Any]:
        """GET /info and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: If the response status is an error.
        """
        resp = await self.client.get("/info")
        resp.raise_for_status()
        return resp.json()

    async def proxy_groq_chat(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST a chat-completions payload to the task app's Groq proxy route.

        Raises:
            httpx.HTTPStatusError: If the response status is an error.
        """
        resp = await self.client.post("/proxy/groq/v1/chat/completions", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def vllm_chat(self, vllm_base_url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST a chat-completions payload directly to a vLLM server.

        A short-lived client (60 second timeout) is used because the target
        host differs from the task app.

        Returns:
            The decoded JSON body on success. If the body is not valid JSON,
            ``{"error": "invalid_json", "raw": <first 800 chars>}`` stands in
            for it. For HTTP status >= 400 the body is wrapped as
            ``{"error": <body>}`` instead of raising.
        """
        async with httpx.AsyncClient(base_url=vllm_base_url.rstrip("/"), timeout=60.0) as c:
            resp = await c.post("/v1/chat/completions", json=payload)
            # Do not raise for status to surface body in errors
            try:
                data = resp.json()
            except Exception:
                data = {"error": "invalid_json", "raw": resp.text[:800]}
            if resp.status_code >= 400:
                return {"error": data}
            return data

    async def rollout(self, *, run_id: str, env_name: str, seed: int, difficulty: str, policy_name: str, policy_config: Dict[str, Any], max_turns: int) -> Dict[str, Any]:
        """POST /rollout to run a whole episode on the server side.

        The ``ops`` list alternates ``"agent"`` and ``"env"`` once per turn,
        and ``on_done`` is set to ``"terminate"``.

        Returns:
            The decoded JSON response body.

        Raises:
            httpx.HTTPStatusError: If the response status is an error.
        """
        payload: Dict[str, Any] = {
            "run_id": run_id,
            "env": {
                "env_name": env_name,
                "config": {"difficulty": difficulty},
                "seed": seed,
            },
            "policy": {
                "policy_name": policy_name,
                "config": policy_config,
            },
            "ops": ["agent", "env"] * max_turns,
            "on_done": "terminate",
        }
        # Ensure X-API-Key is included
        resp = await self.client.post("/rollout", json=payload, headers=self._auth_headers())
        resp.raise_for_status()
        return resp.json()
|
|
130
|
+
|
|
131
|
+
# Runtime defaults read from the environment at import time.
# main() may override each of them from a TOML config file.
TASK_APP_URL = os.environ.get("TASK_APP_URL", "https://YOUR-TASK-APP.modal.run").rstrip("/")  # task app root, no trailing slash
MODEL = os.environ.get("EVAL_MODEL", "qwen/qwen3-32b")  # model identifier sent to the inference backend
NUM_EPISODES = int(os.environ.get("NUM_EPISODES", "3"))  # episodes to run (seeds 1..N)
MAX_TURNS = int(os.environ.get("MAX_TURNS", "10"))  # per-episode turn cap
CONCURRENCY = int(os.environ.get("CONCURRENCY", "1"))  # episodes allowed to run at once
|
|
136
|
+
|
|
137
|
+
def _interact_tool_schema() -> List[Dict[str, Any]]:
|
|
138
|
+
return [{
|
|
139
|
+
"type": "function",
|
|
140
|
+
"function": {
|
|
141
|
+
"name": "interact",
|
|
142
|
+
"description": "Perform actions in the Crafter environment.",
|
|
143
|
+
"parameters": {
|
|
144
|
+
"type": "object",
|
|
145
|
+
"properties": {
|
|
146
|
+
"actions": {"type": "array", "items": {"type": "string"}},
|
|
147
|
+
"reasoning": {"type": "string"},
|
|
148
|
+
},
|
|
149
|
+
"required": ["actions", "reasoning"],
|
|
150
|
+
},
|
|
151
|
+
},
|
|
152
|
+
}]
|
|
153
|
+
|
|
154
|
+
def _build_messages_from_observation(observation: Dict[str, Any], history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
155
|
+
inv = observation.get("inventory") or {}
|
|
156
|
+
pos = observation.get("player_position") or []
|
|
157
|
+
ach = observation.get("achievements_status") or {}
|
|
158
|
+
turns_taken = observation.get("num_steps_taken") or 0
|
|
159
|
+
user_lines: List[str] = []
|
|
160
|
+
user_lines.append("Environment: CrafterClassic")
|
|
161
|
+
user_lines.append(f"Player position: {pos}")
|
|
162
|
+
user_lines.append(f"Inventory: {json.dumps(inv, ensure_ascii=False)}")
|
|
163
|
+
unlocked = [k for k, v in ach.items() if v]
|
|
164
|
+
if unlocked:
|
|
165
|
+
user_lines.append(f"Unlocked achievements: {unlocked}")
|
|
166
|
+
user_lines.append("Provide 2-5 actions as a plan to explore and progress.")
|
|
167
|
+
# short history summary
|
|
168
|
+
if history:
|
|
169
|
+
last = history[-1]
|
|
170
|
+
user_lines.append(f"Last actions: {last.get('actions')}")
|
|
171
|
+
content = "\n".join(user_lines)
|
|
172
|
+
return [{"role": "user", "content": content}]
|
|
173
|
+
|
|
174
|
+
def _parse_tool_calls_from_openai_response(data: Dict[str, Any]) -> List[str]:
|
|
175
|
+
try:
|
|
176
|
+
choices = data.get("choices")
|
|
177
|
+
if isinstance(choices, list) and choices:
|
|
178
|
+
msg = choices[0].get("message", {}) if isinstance(choices[0], dict) else {}
|
|
179
|
+
tcs = msg.get("tool_calls")
|
|
180
|
+
if isinstance(tcs, list) and tcs:
|
|
181
|
+
fn = tcs[0].get("function", {}) if isinstance(tcs[0], dict) else {}
|
|
182
|
+
args = fn.get("arguments")
|
|
183
|
+
if isinstance(args, str):
|
|
184
|
+
try:
|
|
185
|
+
obj = json.loads(args)
|
|
186
|
+
except Exception:
|
|
187
|
+
obj = {}
|
|
188
|
+
elif isinstance(args, dict):
|
|
189
|
+
obj = args
|
|
190
|
+
else:
|
|
191
|
+
obj = {}
|
|
192
|
+
acts = obj.get("actions")
|
|
193
|
+
if isinstance(acts, list):
|
|
194
|
+
return [str(a) for a in acts][:5]
|
|
195
|
+
except Exception:
|
|
196
|
+
pass
|
|
197
|
+
# Fallback: parse JSON object from assistant text
|
|
198
|
+
try:
|
|
199
|
+
choices = data.get("choices")
|
|
200
|
+
msg = choices[0].get("message", {}) if isinstance(choices, list) and choices else {}
|
|
201
|
+
content = msg.get("content")
|
|
202
|
+
text = ""
|
|
203
|
+
if isinstance(content, str):
|
|
204
|
+
text = content
|
|
205
|
+
elif isinstance(content, list):
|
|
206
|
+
text = "\n".join(str(part.get("text")) for part in content if isinstance(part, dict) and part.get("text"))
|
|
207
|
+
for raw in re.findall(r"\{[\s\S]*\}", text or ""):
|
|
208
|
+
try:
|
|
209
|
+
obj = json.loads(raw)
|
|
210
|
+
if isinstance(obj, dict) and obj.get("tool") in ("interact", None):
|
|
211
|
+
acts = obj.get("args", {}).get("actions")
|
|
212
|
+
if isinstance(acts, list):
|
|
213
|
+
return [str(a) for a in acts][:5]
|
|
214
|
+
except Exception:
|
|
215
|
+
continue
|
|
216
|
+
except Exception:
|
|
217
|
+
pass
|
|
218
|
+
return []
|
|
219
|
+
|
|
220
|
+
async def _choose_actions_via_llm(client: TaskAppClient, provider: str, model: str, observation: Dict[str, Any], history: List[Dict[str, Any]]) -> List[str]:
    """Ask the model for the next actions and return them as a list of names.

    Args:
        client: Task app client used to reach the inference backend.
        provider: ``"groq"`` (via the task app proxy) or ``"vllm"`` (direct
            call to the base URL advertised by ``/info``).
        model: Model identifier placed in the request.
        observation: Current environment observation.
        history: Previous turns, used for the prompt summary.

    Returns:
        Parsed action names, or ``[]`` when the provider is unknown, no vLLM
        base URL is advertised, the vLLM call reports an error, or nothing
        could be parsed.
    """
    request: Dict[str, Any] = dict(
        model=model,
        messages=_build_messages_from_observation(observation, history),
        tools=_interact_tool_schema(),
        max_tokens=8192,
        temperature=0.2,
    )

    if provider == "groq":
        # Groq path: avoid forcing tool_choice to reduce 400 errors; proxy will synthesize if missing
        response = await client.proxy_groq_chat(request)
        return _parse_tool_calls_from_openai_response(response) or []

    if provider != "vllm":
        return []

    info = await client.get_info()
    vllm_base = ((info or {}).get("inference") or {}).get("base_url")
    if not vllm_base:
        return []

    # For vLLM path, we can force the single tool
    forced_request = {
        **request,
        "tool_choice": {"type": "function", "function": {"name": "interact"}},
    }
    response = await client.vllm_chat(str(vllm_base), forced_request)
    if isinstance(response, dict) and response.get("error"):
        return []
    return _parse_tool_calls_from_openai_response(response) or []
|
|
247
|
+
|
|
248
|
+
def _expand_actions_to_tool_calls(actions: List[str]) -> List[Dict[str, Any]]:
|
|
249
|
+
out: List[Dict[str, Any]] = []
|
|
250
|
+
for a in actions[:5]:
|
|
251
|
+
out.append({"tool": "interact", "args": {"action": a}})
|
|
252
|
+
return out
|
|
253
|
+
|
|
254
|
+
def _detect_provider(model: str) -> str:
|
|
255
|
+
m = (model or "").lower()
|
|
256
|
+
if "qwen/qwen3-32b" in m or "qwen-2.5-" in m or m.startswith("groq:"):
|
|
257
|
+
return "groq"
|
|
258
|
+
return "vllm"
|
|
259
|
+
|
|
260
|
+
def _rollout_inference_url_from_cfg(cfg: Dict[str, Any], default_vllm: Optional[str]) -> Optional[str]:
|
|
261
|
+
# Prefer explicit inference_url in TOML; else fall back to discovered vLLM base
|
|
262
|
+
url = cfg.get("inference_url")
|
|
263
|
+
if isinstance(url, str) and url:
|
|
264
|
+
return url
|
|
265
|
+
return default_vllm
|
|
266
|
+
|
|
267
|
+
async def eval_episode(client: TaskAppClient, seed: int) -> Dict[str, Any]:
    """Run one client-driven episode against the ``CrafterClassic`` environment.

    The environment is initialized with ``seed`` and the difficulty from the
    ``DIFFICULTY`` environment variable (default ``"easy"``). Each turn asks
    the model for actions, falling back to ``["move_up", "do"]``, and steps
    the environment until it reports ``done`` or ``MAX_TURNS`` is reached.
    The environment is always terminated afterwards (best effort).

    Args:
        client: Task app client used for initialize/step/terminate.
        seed: Seed forwarded to the environment.

    Returns:
        ``{"seed", "turns", "achievements"}``, where ``achievements`` is the
        sorted list of names seen with a truthy flag in any observation's
        ``achievements_status`` mapping.

    Raises:
        RuntimeError: If initialize does not return a non-empty string ``env_id``.
    """
    env_name = "CrafterClassic"
    history: List[Dict[str, Any]] = []
    achievements: set[str] = set()
    turns = 0

    def _record_achievements(obs: Dict[str, Any]) -> None:
        # Reads the same "achievements_status" key that the prompt builder and
        # the rollout summary use. Assumes the step observation carries it at
        # the top level — TODO confirm against the task app response.
        status = obs.get("achievements_status")
        if isinstance(status, dict):
            achievements.update(name for name, unlocked in status.items() if unlocked)

    # Initialize environment
    init_cfg: Dict[str, Any] = {"seed": seed, "world_config": {"difficulty": os.getenv("DIFFICULTY", "easy")}}
    created = await client.initialize(env_name, init_cfg)
    env_id = created.get("env_id") if isinstance(created, dict) else None
    if not isinstance(env_id, str) or not env_id:
        raise RuntimeError(f"Invalid env_id from initialize: {created}")

    done = False
    provider = _detect_provider(MODEL)
    observation = created.get("observation")
    if not isinstance(observation, dict):
        observation = {}
    _record_achievements(observation)

    try:
        while turns < MAX_TURNS and not done:
            # Ask LLM for actions; fallback to a simple exploratory pair
            chosen_actions = await _choose_actions_via_llm(client, provider, MODEL, observation, history)
            if not chosen_actions:
                chosen_actions = ["move_up", "do"]
            tool_calls = _expand_actions_to_tool_calls(chosen_actions)
            step = await client.step(env_name, env_id, tool_calls)
            done = bool(step.get("done"))
            turns += 1
            history.append({"actions": chosen_actions, "reasoning": "explore then interact"})
            # Update observation for next turn if available
            nxt = step.get("observation")
            if isinstance(nxt, dict):
                observation = nxt
                _record_achievements(nxt)
    finally:
        # Termination is best-effort cleanup: a failure here must not mask
        # the episode result or an exception already in flight.
        try:
            await client.terminate(env_name, env_id)
        except Exception:
            pass

    return {"seed": seed, "turns": turns, "achievements": sorted(achievements)}
|
|
308
|
+
|
|
309
|
+
def _load_local_env() -> None:
    """Best-effort load of ``KEY=VALUE`` lines from a ``.env`` next to this file.

    Blank lines, ``#`` comments and lines without ``=`` are skipped. Values
    are stripped of surrounding quotes. Variables already present in the
    environment are not overwritten.
    """
    try:
        env_path = Path(__file__).resolve().parent / ".env"
        if not env_path.exists():
            return
        for line in env_path.read_text(encoding="utf-8").splitlines():
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, value = line.partition("=")
            os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'"))
    except Exception:
        # Deliberately silent: the .env file is an optional convenience and
        # must never prevent the evaluation from starting.
        pass


def _rollout_episode_stats(seed: int, response: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce a ``/rollout`` response to ``{"seed", "turns", "achievements"}``.

    Only the first trajectory is inspected: ``length`` gives the turn count
    and ``final.observation.achievements_status`` gives the unlocked
    achievements. Missing or malformed fields yield ``0`` and ``[]``.
    """
    achievements: List[str] = []
    length = 0
    trajectories = response.get("trajectories") if isinstance(response, dict) else None
    first = trajectories[0] if isinstance(trajectories, list) and trajectories else None
    if isinstance(first, dict):
        final = first.get("final")
        final_obs = final.get("observation") if isinstance(final, dict) else None
        status = final_obs.get("achievements_status") if isinstance(final_obs, dict) else None
        if isinstance(status, dict):
            achievements = sorted(name for name, unlocked in status.items() if unlocked)
        try:
            length = int(first.get("length") or 0)
        except (TypeError, ValueError):
            length = 0
    return {"seed": seed, "turns": length, "achievements": achievements}


def _summarize_rollouts(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Build the printed summary: per-episode results plus aggregate statistics."""
    counts = [len(r.get("achievements") or []) for r in results]
    turns = [int(r.get("turns") or 0) for r in results]
    frequency = Counter(name for r in results for name in (r.get("achievements") or []))
    return {
        "episodes": results,
        "aggregate": {
            "completed": sum(1 for r in results if not r.get("error")),
            "total": len(results),
            "avg_turns": (sum(turns) / len(turns)) if turns else 0.0,
            "avg_achievements": (sum(counts) / len(counts)) if counts else 0.0,
            "achievements_freq": dict(frequency),
        },
    }


async def main() -> None:
    """Command-line entry point: run the evaluation and print JSON results.

    Settings come from environment variables (optionally loaded from a local
    ``.env``) and may be overridden by a TOML file passed with ``--toml``.
    With ``--use-rollout`` each episode is one server-side ``/rollout`` call
    and an aggregate summary is printed; otherwise episodes are driven
    turn-by-turn from this process through ``eval_episode``.

    Raises:
        RuntimeError: If the task app URL is still the placeholder, or no
            inference URL can be resolved for rollouts.
    """
    global TASK_APP_URL, MODEL, NUM_EPISODES, MAX_TURNS, CONCURRENCY

    # Best-effort load local .env if present (ensures ENVIRONMENT_API_KEY for rollout)
    _load_local_env()

    parser = argparse.ArgumentParser(description="Baseline eval against task app with optional TOML config")
    parser.add_argument("--toml", help="Path to TOML config file", default=None)
    parser.add_argument("--use-rollout", action="store_true", help="Use server-side rollout endpoint for eval")
    args = parser.parse_args()

    cfg: Dict[str, Any] = {}
    if args.toml:
        with open(args.toml, "rb") as f:
            cfg = tomllib.load(f)
        # Map known keys; tolerate missing
        TASK_APP_URL = (cfg.get("task_app_url") or TASK_APP_URL).rstrip("/")
        MODEL = cfg.get("model") or MODEL
        NUM_EPISODES = int(cfg.get("num_episodes") or NUM_EPISODES)
        MAX_TURNS = int(cfg.get("max_turns") or MAX_TURNS)
        CONCURRENCY = int(cfg.get("concurrency") or CONCURRENCY)
        if "difficulty" in cfg:
            os.environ["DIFFICULTY"] = str(cfg.get("difficulty"))

    # Replace placeholder URLs with env if present
    if "your-task-app.modal.run" in TASK_APP_URL.lower():
        env_url = os.getenv("TASK_APP_URL")
        if env_url:
            TASK_APP_URL = env_url.rstrip("/")
        else:
            raise RuntimeError("TASK_APP_URL is a placeholder. Set task_app_url in TOML or export TASK_APP_URL.")

    print(f"Task App: {TASK_APP_URL}")
    print(f"Model: {MODEL} Episodes: {NUM_EPISODES} Max turns: {MAX_TURNS} Concurrency: {CONCURRENCY}")

    sem = asyncio.Semaphore(max(CONCURRENCY, 1))
    seeds = range(1, NUM_EPISODES + 1)

    async with TaskAppClient(TASK_APP_URL, api_key=os.getenv("ENVIRONMENT_API_KEY")) as client:
        if args.use_rollout:
            # Use server-side rollout; derive inference URL per provider
            info = await client.get_info()
            default_vllm = ((info or {}).get("inference") or {}).get("base_url")
            inf_url = _rollout_inference_url_from_cfg(cfg, default_vllm)
            if not inf_url:
                raise RuntimeError("Could not resolve inference URL for rollout")

            # Build policy config from TOML (explicit control; no server-side guessing).
            # MODEL already reflects the TOML "model" when it is set, so an empty
            # or null TOML value falls back to the same model shown in the banner.
            policy_cfg: Dict[str, Any] = {"model": MODEL, "inference_url": inf_url}
            for key in ("max_tokens", "temperature", "top_p", "thinking_mode", "thinking_budget", "use_tools"):
                if cfg.get(key) is not None:
                    policy_cfg[key] = cfg[key]

            async def _run(seed: int) -> Dict[str, Any]:
                async with sem:
                    try:
                        response = await client.rollout(
                            run_id=f"eval-{seed}",
                            env_name="crafter",
                            seed=seed,
                            difficulty=os.getenv("DIFFICULTY", "easy"),
                            policy_name=cfg.get("policy_name", "crafter"),
                            policy_config=dict(policy_cfg),
                            max_turns=MAX_TURNS,
                        )
                        return _rollout_episode_stats(seed, response)
                    except Exception as e:
                        # Record the failure per episode so the other episodes
                        # still finish and appear in the summary.
                        return {"seed": seed, "turns": 0, "achievements": [], "error": str(e)}

            results = await asyncio.gather(*[asyncio.create_task(_run(i)) for i in seeds])
            print(json.dumps(_summarize_rollouts(list(results)), indent=2))
        else:
            async def _run(seed: int) -> Dict[str, Any]:
                async with sem:
                    return await eval_episode(client, seed)

            results = await asyncio.gather(*[asyncio.create_task(_run(i)) for i in seeds])
            print(json.dumps({"episodes": results}, indent=2))
|
|
432
|
+
|
|
433
|
+
# Script entry point: run the async driver only when this file is executed directly.
if __name__ == "__main__":
    asyncio.run(main())
|