synth-ai 0.2.9.dev2__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/analyze_semantic_words.sh +17 -0
- examples/common_old/backend.py +21 -0
- examples/crafter_debug_render.py +180 -0
- examples/evals_old/README.md +98 -0
- examples/evals_old/__init__.py +6 -0
- examples/evals_old/compare_models.py +1037 -0
- examples/evals_old/example_log.md +145 -0
- examples/evals_old/run_demo.sh +126 -0
- examples/evals_old/trace_analysis.py +270 -0
- examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
- examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
- examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
- examples/finetuning_old/synth_qwen_v1/README.md +68 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
- examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
- examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
- examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
- examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
- examples/finetuning_old/synth_qwen_v1/util.py +147 -0
- examples/rl/README.md +169 -0
- examples/rl/configs/eval_base_qwen.toml +15 -0
- examples/rl/configs/eval_rl_qwen.toml +11 -0
- examples/rl/configs/rl_from_base_qwen.toml +35 -0
- examples/rl/configs/rl_from_base_qwen17.toml +74 -0
- examples/rl/configs/rl_from_ft_qwen.toml +35 -0
- examples/rl/download_dataset.py +64 -0
- examples/rl/run_eval.py +435 -0
- examples/rl/run_rl_and_save.py +94 -0
- examples/rl/task_app/README.md +22 -0
- {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
- examples/rl/task_app/math_task_app.py +107 -0
- examples/rl_old/task_app.py +962 -0
- examples/run_crafter_demo.sh +10 -0
- examples/warming_up_to_rl/analyze_trace_db.py +420 -0
- examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
- examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
- examples/warming_up_to_rl/export_trace_sft.py +541 -0
- examples/warming_up_to_rl/groq_test.py +88 -0
- examples/warming_up_to_rl/manage_secrets.py +127 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/readme.md +172 -0
- examples/warming_up_to_rl/run_eval.py +434 -0
- examples/warming_up_to_rl/run_fft_and_save.py +309 -0
- examples/warming_up_to_rl/run_local_rollout.py +188 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
- examples/warming_up_to_rl/run_rl_and_save.py +101 -0
- examples/warming_up_to_rl/run_rollout_remote.py +129 -0
- examples/warming_up_to_rl/task_app/README.md +38 -0
- {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
- synth_ai/api/train/config_finder.py +18 -18
- synth_ai/api/train/env_resolver.py +28 -1
- synth_ai/cli/task_apps.py +264 -55
- synth_ai/demo_registry.py +7 -7
- synth_ai/demos/demo_task_apps/crafter/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +54 -0
- synth_ai/demos/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +165 -0
- synth_ai/task/apps/__init__.py +54 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +112 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
examples/rl/run_eval.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Evaluate math single-step task policies against the task app."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import asyncio
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
12
|
+
|
|
13
|
+
import httpx
|
|
14
|
+
import tomllib
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TaskAppClient:
    """Minimal async client for math single-step task app.

    Can be used as an async context manager (the HTTP client is closed on
    exit) or directly, in which case the underlying ``httpx.AsyncClient`` is
    created lazily on first access of :attr:`client`.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None) -> None:
        """Store connection settings; no network resources are created here.

        Args:
            base_url: Task app root URL; trailing slashes are stripped.
            api_key: Optional key sent as the ``X-API-Key`` header.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self._client: Optional[httpx.AsyncClient] = None

    def _build_client(self) -> httpx.AsyncClient:
        """Create the HTTP client with the shared base URL, headers and timeout."""
        headers = {"X-API-Key": self.api_key} if self.api_key else {}
        return httpx.AsyncClient(
            base_url=self.base_url,
            headers=headers,
            timeout=httpx.Timeout(120.0),
            follow_redirects=True,
        )

    async def __aenter__(self) -> "TaskAppClient":
        # Reuse a client that was already created lazily instead of replacing
        # it, which would leave the earlier client unclosed.
        if self._client is None:
            self._client = self._build_client()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if self._client:
            await self._client.aclose()
            self._client = None

    @property
    def client(self) -> httpx.AsyncClient:
        """Return the HTTP client, creating it on first use."""
        if self._client is None:
            self._client = self._build_client()
        return self._client

    async def initialize(self, split: str, seed: int | None) -> Dict[str, Any]:
        """POST ``/env/math/initialize`` and return the decoded JSON body.

        ``seed`` is only included in the payload when it is not ``None``.
        Raises ``httpx.HTTPStatusError`` on a 4xx/5xx response.
        """
        payload: Dict[str, Any] = {"config": {"split": split}}
        if seed is not None:
            payload["seed"] = seed
        resp = await self.client.post("/env/math/initialize", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def step(self, env_id: str, tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
        """POST the tool calls to ``/env/math/step`` and return the decoded JSON body."""
        payload = {"env_id": env_id, "action": {"tool_calls": tool_calls}}
        resp = await self.client.post("/env/math/step", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def terminate(self, env_id: str) -> None:
        """Best-effort POST to ``/env/math/terminate``; never raises."""
        try:
            await self.client.post("/env/math/terminate", json={"env_id": env_id})
        except Exception:
            # Deliberately swallowed: cleanup must not mask the result (or the
            # original error) of the episode that is being torn down.
            pass

    async def get_info(self) -> Dict[str, Any]:
        """GET ``/info`` and return the decoded JSON body."""
        resp = await self.client.get("/info")
        resp.raise_for_status()
        return resp.json()

    async def rollout(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` to ``/rollout`` and return the decoded JSON body."""
        resp = await self.client.post("/rollout", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def post_inference(
        self,
        url: str,
        payload: Dict[str, Any],
        *,
        headers: Dict[str, str] | None = None,
    ) -> Dict[str, Any]:
        """POST ``payload`` to an absolute inference ``url`` with a throwaway client.

        A separate short-lived client (60s timeout) is used, so the task app's
        default headers are not attached to this request.
        """
        async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as c:
            resp = await c.post(url, json=payload, headers=headers)
            resp.raise_for_status()
            return resp.json()
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
TOOL_NAME = "math_submit"
DEFAULT_SPLIT = os.getenv("MATH_EVAL_DEFAULT_SPLIT", "validation")


def _math_tool_schema() -> List[Dict[str, Any]]:
    """Return the tool list advertising the single ``math_submit`` function.

    The schema requires a string ``answer``, allows an optional string
    ``explanation`` and forbids any other argument.
    """
    answer_field = {
        "type": "string",
        "description": "Final answer in simplest form",
    }
    explanation_field = {
        "type": "string",
        "description": "Optional explanation of reasoning",
    }
    parameters = {
        "type": "object",
        "properties": {"answer": answer_field, "explanation": explanation_field},
        "required": ["answer"],
        "additionalProperties": False,
    }
    function_spec = {
        "name": TOOL_NAME,
        "description": "Submit the final answer for the math problem.",
        "parameters": parameters,
    }
    return [{"type": "function", "function": function_spec}]
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _build_messages(problem: str) -> List[Dict[str, Any]]:
    """Build the two-message chat prompt (system, then user) for one problem."""
    system_prompt = (
        "You solve math problems. Always respond with a single math_submit tool call "
        "containing only the final answer."
    )
    user_prompt = f"Problem:\n{problem}\nReturn the final answer via math_submit."
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _parse_tool_calls(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Extract tool calls from the first choice of a chat-completions response.

    Args:
        data: Decoded response body; only ``choices[0].message.tool_calls``
            is read.

    Returns:
        One ``{"tool": <function name>, "args": <dict>}`` entry per tool call.
        ``args`` is always a dict: arguments that are missing, are not valid
        JSON, or decode to something other than a JSON object become ``{}``.
    """
    choices = data.get("choices") or []
    if not choices:
        return []
    message = choices[0].get("message") or {}

    tool_calls: List[Dict[str, Any]] = []
    for call in message.get("tool_calls") or []:
        function = call.get("function") or {}
        arguments = function.get("arguments")
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except (ValueError, RecursionError):
                # Malformed (or pathologically nested) JSON from the model.
                arguments = None
        # json.loads may legally yield a list/str/number; downstream code
        # calls .get() on args, so anything that is not an object is dropped.
        parsed_args: Dict[str, Any] = dict(arguments) if isinstance(arguments, dict) else {}
        tool_calls.append({"tool": function.get("name"), "args": parsed_args})
    return tool_calls
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _detect_provider(model: str, hint: Optional[str]) -> str:
    """Pick the inference provider name.

    A non-empty ``hint`` wins (lower-cased). Otherwise a model name starting
    with ``groq:`` (case-insensitive) selects ``"groq"``; anything else is
    ``"generic"``.
    """
    if hint:
        return hint.lower()
    model_name = (model or "").lower()
    return "groq" if model_name.startswith("groq:") else "generic"
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _resolve_inference_url(base_url: str) -> str:
    """Normalise an inference base URL into a chat-completions endpoint.

    Trailing slashes are removed first. A URL already ending in
    ``/chat/completions`` is returned as is; one that ends in ``/v1`` or
    contains ``/v1/`` gets ``/chat/completions`` appended; any other URL gets
    ``/v1/chat/completions`` appended.

    Raises:
        RuntimeError: If ``base_url`` is empty (or only slashes).
    """
    url = (base_url or "").rstrip("/")
    if not url:
        raise RuntimeError("inference_url cannot be empty")
    if url.endswith("/chat/completions"):
        return url
    has_version_segment = url.endswith("/v1") or "/v1/" in url
    suffix = "/chat/completions" if has_version_segment else "/v1/chat/completions"
    return f"{url}{suffix}"
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
async def _choose_actions(
    client: TaskAppClient,
    provider: str,
    model: str,
    problem: str,
    policy_cfg: Dict[str, Any],
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Ask the model for a ``math_submit`` tool call for one problem.

    Args:
        client: Task app client whose HTTP client performs the request.
        provider: ``"groq"`` routes through the task app's Groq proxy; any
            other value posts to ``policy_cfg["inference_url"]``.
        model: Model identifier placed in the request payload.
        problem: Problem text embedded in the prompt.
        policy_cfg: Sampling settings (``temperature``, ``top_p``,
            ``max_tokens``) plus ``inference_url`` / ``headers`` /
            ``extra_headers`` for the non-groq path.

    Returns:
        ``(tool_calls, body)``: the parsed tool calls and the raw response body.

    Raises:
        RuntimeError: Missing ``inference_url``, a request timeout, or a
            4xx/5xx answer from the inference service (non-groq path).
        httpx.HTTPStatusError: 4xx/5xx answer from the Groq proxy.
    """
    payload: Dict[str, Any] = {
        "model": model,
        "messages": _build_messages(problem),
        "tools": _math_tool_schema(),
        "tool_choice": {"type": "function", "function": {"name": TOOL_NAME}},
        "temperature": policy_cfg.get("temperature", 0.0),
        "top_p": policy_cfg.get("top_p", 1.0),
        "max_tokens": policy_cfg.get("max_tokens", 256),
    }

    if provider == "groq":
        # Task app proxies Groq requests; reuse existing headers on the client
        response = await client.client.post(
            "/proxy/groq/v1/chat/completions", json=payload
        )
        response.raise_for_status()
        body = response.json()
    else:
        inference_url = policy_cfg.get("inference_url")
        if not inference_url:
            raise RuntimeError("inference_url required for non-groq evaluations")
        # "headers" wins over "extra_headers" when both define the same key.
        headers = dict(policy_cfg.get("headers") or {})
        for key, value in (policy_cfg.get("extra_headers") or {}).items():
            headers.setdefault(key, value)
        final_url = _resolve_inference_url(inference_url)
        # NOTE(review): this reuses the task-app client, so its default
        # X-API-Key header (when configured) is also sent to the inference
        # URL. Confirm that is intended before pointing this at a third party.
        try:
            response = await client.client.post(
                final_url,
                json=payload,
                headers=headers or None,
            )
        except httpx.TimeoutException as exc:
            # TimeoutException is the base of connect/read/write/pool
            # timeouts, so every timeout flavour gets the same clear message.
            raise RuntimeError(
                "Inference request timed out. Check the inference service."
            ) from exc
        try:
            body = response.json()
        except ValueError:
            # Non-JSON body (e.g. an HTML error page): keep a short excerpt.
            body = {"raw": response.text[:800]}
        if response.status_code >= 500:
            raise RuntimeError(
                f"Inference server error {response.status_code}: {body}")
        if response.status_code >= 400:
            raise RuntimeError(
                f"Inference request invalid ({response.status_code}): {body}")

    tool_calls = _parse_tool_calls(body)
    return tool_calls, body
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _tool_to_answer(tool_calls: List[Dict[str, Any]]) -> str:
    """Return the stripped ``answer`` argument of the first tool call.

    Returns ``""`` when there are no tool calls, when the first call has no
    usable ``args`` mapping, or when ``answer`` is absent/``None``. Non-string
    answers (for example the number ``0``) are converted with ``str`` rather
    than being treated as missing.
    """
    if not tool_calls:
        return ""
    args = tool_calls[0].get("args")
    if not isinstance(args, dict):
        return ""
    answer = args.get("answer")
    if answer is None:
        return ""
    return str(answer).strip()
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
async def eval_episode(
    client: TaskAppClient,
    *,
    split: str,
    seed: Optional[int],
    model: str,
    provider: str,
    policy_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    """Run one client-driven episode: initialize, infer, step, terminate.

    Args:
        client: Task app client used for the environment calls.
        split: Dataset split passed to ``initialize``.
        seed: Optional seed passed to ``initialize``.
        model: Model identifier used for inference.
        provider: Provider name forwarded to ``_choose_actions``.
        policy_cfg: Policy/inference settings forwarded to ``_choose_actions``.

    Returns:
        A result dict with ``seed``, ``split``, ``problem``, ``answer``,
        ``expected``, ``reward``, ``status``, ``correct``, ``raw_response``
        and ``tool_calls``.
    """
    created = await client.initialize(split, seed)
    env_id = created["env_id"]
    problem = (created.get("observation") or {}).get("problem") or ""

    try:
        tool_calls, raw_response = await _choose_actions(client, provider, model, problem, policy_cfg)
        answer = _tool_to_answer(tool_calls)
        result = await client.step(env_id, tool_calls)
    finally:
        # Always release the environment, including when inference or the
        # step call fails; terminate() itself is best-effort and never raises.
        await client.terminate(env_id)

    info = result.get("info") or {}
    reward = result.get("reward") or 0.0
    # Fall back to a reward-derived status when the app does not report one.
    status = info.get("status") or ("correct" if reward > 0 else "incorrect")
    return {
        "seed": seed,
        "split": split,
        "problem": problem,
        "answer": answer,
        "expected": info.get("expected_answer"),
        "reward": reward,
        "status": status,
        "correct": bool(info.get("correct")),
        "raw_response": raw_response,
        "tool_calls": tool_calls,
    }
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
async def eval_via_rollout(
    client: TaskAppClient,
    *,
    run_id: str,
    split: str,
    seed: Optional[int],
    model: str,
    policy_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    """Run one episode through the task app's server-side ``/rollout`` endpoint.

    Args:
        client: Task app client used to post the rollout request.
        run_id: Identifier sent as ``run_id``.
        split: Dataset split placed in the env config.
        seed: Seed placed in the env section.
        model: Unused here (the model travels inside ``policy_cfg``); kept for
            signature parity with ``eval_episode``.
        policy_cfg: Sent verbatim as the policy config.

    Returns:
        A result dict shaped like ``eval_episode``'s, built from the first
        step of the first trajectory, or a minimal
        ``{"reward", "correct", "status": "missing"}`` dict when the response
        has no trajectories.
    """
    payload = {
        "run_id": run_id,
        "env": {
            "env_name": "math",
            "config": {"split": split},
            "seed": seed,
        },
        "policy": {
            "policy_name": "math-single-step",
            "config": policy_cfg,
        },
        "ops": ["agent", "env"],
        "on_done": "terminate",
    }
    resp = await client.rollout(payload)
    trajs = resp.get("trajectories") or []
    if not trajs:
        return {"reward": 0.0, "correct": False, "status": "missing"}

    steps = trajs[0].get("steps") or []
    step = steps[0] if steps else {}
    info = step.get("info") or {}
    tool_calls = step.get("tool_calls") or []
    # The original referenced an undefined `observation` name here, which
    # raised NameError on every successful rollout.
    # NOTE(review): the step schema is not visible in this file; "obs" and
    # "observation" are assumed key names — confirm against the rollout contract.
    observation = step.get("obs") or step.get("observation") or {}
    if not isinstance(observation, dict):
        observation = {}
    return {
        "seed": seed,
        "split": split,
        "problem": observation.get("problem"),
        "answer": _tool_to_answer(tool_calls),
        "expected": info.get("expected_answer"),
        "reward": step.get("reward") or 0.0,
        "status": info.get("status"),
        "correct": bool(info.get("correct")),
        "raw_response": resp,
        "tool_calls": tool_calls,
    }
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _load_config(path: Optional[str]) -> Dict[str, Any]:
|
|
332
|
+
if not path:
|
|
333
|
+
return {}
|
|
334
|
+
with open(path, "rb") as fh:
|
|
335
|
+
return tomllib.load(fh)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _default_policy_cfg(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Derive the policy config from the loaded TOML plus environment defaults.

    Starts from a copy of ``cfg["policy"]`` and then:

    * fills ``inference_url`` from ``INFERENCE_URL`` when absent;
    * inherits ``max_tokens``, ``temperature``, ``top_p``, ``headers`` and
      ``extra_headers`` from the top level of ``cfg`` when the policy section
      does not set them;
    * adds ``Authorization: Bearer <SYNTH_API_KEY>`` to ``extra_headers`` when
      neither header mapping already carries an Authorization header.

    Returns:
        A new dict; ``cfg`` itself is not modified.
    """
    policy = dict(cfg.get("policy") or {})
    if "inference_url" not in policy:
        env_url = os.getenv("INFERENCE_URL")
        if env_url:
            policy["inference_url"] = env_url

    for key in ("max_tokens", "temperature", "top_p", "headers", "extra_headers"):
        if key not in policy and key in cfg:
            policy[key] = cfg[key]

    extra_headers = dict(policy.get("extra_headers") or {})
    headers = dict(policy.get("headers") or {})
    # HTTP header names are case-insensitive: a user-supplied "authorization"
    # must suppress the default, otherwise two conflicting credentials are sent.
    has_authorization = any(
        str(name).lower() == "authorization" for name in (*headers, *extra_headers)
    )
    if not has_authorization:
        synth_key = os.getenv("SYNTH_API_KEY")
        if synth_key:
            extra_headers["Authorization"] = f"Bearer {synth_key}"
    if extra_headers:
        policy["extra_headers"] = extra_headers
    return policy
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
async def main() -> None:
    """Evaluate a model on the math task app and print a summary.

    Settings come from the optional ``--toml`` file with environment-variable
    fallbacks (``TASK_APP_URL``, ``EVAL_MODEL``, ``EVAL_SPLIT``,
    ``NUM_EPISODES``, ``ENVIRONMENT_API_KEY``). Episodes use consecutive seeds
    starting at ``seed_start``. With ``--use-rollout`` each episode goes
    through the server-side rollout endpoint instead of client-side inference.

    Raises:
        RuntimeError: If no task app URL is configured.
    """
    parser = argparse.ArgumentParser(description="Evaluate math task app policies")
    parser.add_argument("--toml", help="Path to TOML config", default=None)
    parser.add_argument("--use-rollout", action="store_true", help="Use server-side rollout")
    args = parser.parse_args()

    cfg = _load_config(args.toml)
    task_app_url = (cfg.get("task_app_url") or os.getenv("TASK_APP_URL") or "").rstrip("/")
    if not task_app_url:
        raise RuntimeError("task_app_url missing; set in TOML or export TASK_APP_URL")
    model = cfg.get("model") or os.getenv("EVAL_MODEL") or "groq:qwen-2.5-7b"
    split = cfg.get("split") or os.getenv("EVAL_SPLIT") or DEFAULT_SPLIT
    episodes = int(cfg.get("num_episodes") or os.getenv("NUM_EPISODES") or 50)
    seed_start = int(cfg.get("seed_start") or 0)

    policy_cfg = _default_policy_cfg(cfg)
    provider_hint = cfg.get("provider") or cfg.get("policy", {}).get("provider") or policy_cfg.get("provider")
    provider = _detect_provider(model, provider_hint)
    # The provider only steers routing here; it is not part of the policy
    # config that gets sent along with requests.
    policy_cfg.pop("provider", None)

    api_key = os.getenv("ENVIRONMENT_API_KEY")

    successes = 0
    failures: Dict[str, int] = {}

    async with TaskAppClient(task_app_url, api_key=api_key) as client:
        for episode in range(episodes):
            seed = seed_start + episode
            if args.use_rollout:
                data = await eval_via_rollout(
                    client,
                    run_id=f"eval-{seed}",
                    split=split,
                    seed=seed,
                    model=model,
                    policy_cfg={"model": model, **policy_cfg},
                )
            else:
                data = await eval_episode(
                    client,
                    split=split,
                    seed=seed,
                    model=model,
                    provider=provider,
                    policy_cfg={"model": model, **policy_cfg},
                )
            status = data.get("status") or "unknown"
            if data.get("correct"):
                successes += 1
            else:
                # Only incorrect episodes belong in the failure breakdown;
                # counting correct ones produced misleading "<status>: 0" rows.
                failures[status] = failures.get(status, 0) + 1
            answer = data.get("answer")
            expected = data.get("expected")
            problem = data.get("problem")
            tool_calls = data.get("tool_calls") or []
            print(
                f"Episode {episode+1}/{episodes} seed={seed} status={status} reward={data.get('reward')}\n"
                f"  problem: {problem!r}\n"
                f"  tool   : {tool_calls!r}\n"
                f"  answer : {answer!r}\n  expected: {expected!r}",
                flush=True,
            )

    # max(..., 1) avoids a ZeroDivisionError when zero episodes were requested.
    accuracy = successes / max(episodes, 1)
    print("=== Evaluation Summary ===")
    print(f"Task App: {task_app_url}")
    print(f"Model: {model}")
    print(f"Split: {split}")
    print(f"Episodes: {episodes}")
    print(f"Accuracy: {accuracy:.3f}")
    print("Failure breakdown:")
    # Most frequent failure status first; ties broken alphabetically.
    for status, count in sorted(failures.items(), key=lambda kv: (-kv[1], kv[0])):
        print(f"  {status}: {count}")


if __name__ == "__main__":
    asyncio.run(main())
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Submit math RL training jobs via Synth backend."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any, Dict
|
|
12
|
+
|
|
13
|
+
import requests
|
|
14
|
+
import tomllib
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _load_toml(path: Path) -> Dict[str, Any]:
|
|
18
|
+
if not path.exists():
|
|
19
|
+
print(f"config not found: {path}", file=sys.stderr)
|
|
20
|
+
sys.exit(2)
|
|
21
|
+
with path.open("rb") as fh:
|
|
22
|
+
return tomllib.load(fh)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def main() -> None:
    """Submit a math RL job to the backend and exit with a status code.

    Reads the RL TOML given by ``--config``, resolves the task service URL
    (``--task-url``, then ``TASK_APP_URL``, then ``services.task_url`` in the
    TOML), checks that exactly one of ``[model].source`` / ``[model].base`` is
    set, and POSTs the job to ``<backend>/rl/jobs`` using ``SYNTH_API_KEY``.

    Exit codes: 0 when the backend answers 200/201, 1 for any other response
    or a failed request, 2 for configuration errors.
    """
    parser = argparse.ArgumentParser(description="Create math RL job via backend RL endpoint")
    parser.add_argument("--backend", default=os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api"))
    parser.add_argument("--config", required=True, help="Path to RL TOML config")
    parser.add_argument("--task-url", default=os.getenv("TASK_APP_URL", ""), help="Override task service URL")
    parser.add_argument("--idempotency", default=os.getenv("RL_IDEMPOTENCY_KEY", ""), help="Optional Idempotency-Key header")
    args = parser.parse_args()

    cfg_path = Path(args.config).expanduser()
    cfg = _load_toml(cfg_path)

    services = cfg.get("services") if isinstance(cfg.get("services"), dict) else {}

    task_url = (args.task_url or "").strip() or (os.getenv("TASK_APP_URL") or "").strip() or (services.get("task_url") or "").strip()
    if not task_url:
        print("Missing task service URL. Provide --task-url or set TASK_APP_URL or services.task_url in TOML", file=sys.stderr)
        sys.exit(2)

    model_cfg = cfg.get("model") if isinstance(cfg.get("model"), dict) else {}
    has_source = bool((model_cfg.get("source") or "").strip())
    has_base = bool((model_cfg.get("base") or "").strip())
    # Equal flags mean "both set" or "neither set"; both are invalid.
    if has_source == has_base:
        print("Model section must specify exactly one of [model].source or [model].base", file=sys.stderr)
        sys.exit(2)

    payload: Dict[str, Any] = {
        "job_type": "rl",
        "compute": cfg.get("compute", {}),
        "data": {
            "endpoint_base_url": task_url,
            "config": cfg,
        },
        "tags": cfg.get("tags", {}),
    }

    backend = str(args.backend).rstrip("/")
    url = f"{backend}/rl/jobs"
    api_key = (os.getenv("SYNTH_API_KEY") or os.getenv("synth_key") or "").strip()
    if not api_key:
        print("Missing SYNTH_API_KEY in env", file=sys.stderr)
        sys.exit(2)

    headers = {
        "content-type": "application/json",
        "authorization": f"Bearer {api_key}",
    }
    idem = (args.idempotency or "").strip()
    if idem:
        headers["Idempotency-Key"] = idem

    print(f"[INFO] POST {url}")
    # Only key names are previewed so config values never reach the log.
    preview = {"job_type": payload["job_type"], "data": {"config_keys": list(cfg.keys())}}
    print(f"[INFO] Payload preview: {json.dumps(preview)}")

    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=120)
    except requests.RequestException as exc:
        # Connection refused, DNS failure, timeout, ...: report it as a
        # failed submission instead of dumping a traceback.
        print(f"[ERROR] Request to {url} failed: {exc}", file=sys.stderr)
        sys.exit(1)

    ok = resp.status_code in (200, 201)
    try:
        snippet = resp.json()
    except ValueError:
        # Body is not JSON; show a short excerpt of the raw text instead.
        snippet = resp.text[:300]
    print(f"[INFO] Response: {resp.status_code} {snippet}")
    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    main()
|
|
94
|
+
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Math Single-Step Task App
|
|
2
|
+
|
|
3
|
+
This directory hosts the legacy entrypoint for the math single-step task app. Prefer starting the app via:
|
|
4
|
+
|
|
5
|
+
```bash
|
|
6
|
+
uvx synth-ai serve math-single-step --env-file examples/rl/.env --port 8101
|
|
7
|
+
```
|
|
8
|
+
|
|
9
|
+
If you need to run it directly (e.g., for Modal `modal deploy` compatibility), use:
|
|
10
|
+
|
|
11
|
+
```bash
|
|
12
|
+
python examples/rl/task_app/math_task_app.py --env-file examples/rl/.env --port 8101
|
|
13
|
+
```
|
|
14
|
+
|
|
15
|
+
Environment variables:
|
|
16
|
+
|
|
17
|
+
- `MATH_DATASET_NAME` – defaults to `EleutherAI/math`
|
|
18
|
+
- `MATH_DATASET_CONFIG` – defaults to `algebra__linear_1d`
|
|
19
|
+
- `MATH_DATASET_DEFAULT_SPLIT`, `MATH_DATASET_VALIDATION_SPLIT`, `MATH_DATASET_TEST_SPLIT`
|
|
20
|
+
|
|
21
|
+
The task app enforces a single `math_submit` tool call per episode, enabling RL to reward correct final answers and penalise missing or malformed submissions.
|
|
22
|
+
|
|
@@ -16,7 +16,7 @@ from datasets import load_dataset
|
|
|
16
16
|
from fastapi import APIRouter, HTTPException, Request
|
|
17
17
|
from pydantic import BaseModel, Field
|
|
18
18
|
|
|
19
|
-
from
|
|
19
|
+
from synth_ai.task.contracts import (
|
|
20
20
|
RolloutMetrics,
|
|
21
21
|
RolloutRequest,
|
|
22
22
|
RolloutResponse,
|
|
@@ -24,18 +24,18 @@ from ..contracts import (
|
|
|
24
24
|
RolloutTrajectory,
|
|
25
25
|
TaskInfo,
|
|
26
26
|
)
|
|
27
|
-
from
|
|
28
|
-
from
|
|
29
|
-
from
|
|
30
|
-
from
|
|
31
|
-
from
|
|
27
|
+
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
28
|
+
from synth_ai.task.rubrics import Rubric, load_rubric
|
|
29
|
+
from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
30
|
+
from synth_ai.task.errors import http_exception
|
|
31
|
+
from synth_ai.task.tracing_utils import (
|
|
32
32
|
build_tracer_factory,
|
|
33
33
|
resolve_sft_output_dir,
|
|
34
34
|
resolve_tracing_db_url,
|
|
35
35
|
tracing_env_enabled,
|
|
36
36
|
)
|
|
37
|
-
from
|
|
38
|
-
from . import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
37
|
+
from synth_ai.task.vendors import normalize_vendor_keys
|
|
38
|
+
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
39
39
|
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
40
40
|
|
|
41
41
|
REPO_ROOT = Path(__file__).resolve().parents[3]
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Legacy entrypoint for the math single-step task app."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from fastapi.exceptions import RequestValidationError
|
|
9
|
+
from fastapi.responses import JSONResponse
|
|
10
|
+
from starlette.requests import Request
|
|
11
|
+
|
|
12
|
+
from synth_ai.task.server import create_task_app, run_task_app
|
|
13
|
+
from .math_single_step import build_config
|
|
14
|
+
from synth_ai.task.auth import is_api_key_header_authorized, normalize_environment_api_key
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def fastapi_app():
    """Return a FastAPI application for hosting the math task app.

    The app built by ``create_task_app`` has its GET ``/health`` and
    ``/health/rollout`` routes replaced with handlers that answer 200 even
    for callers without a valid API key, and a request-validation handler is
    installed that logs a header snapshot and returns a trimmed 422 body.
    """

    app = create_task_app(build_config())

    # Replace default health endpoints with auth-tolerant handlers.
    health_paths = {"/health", "/health/rollout"}
    app.router.routes = [
        route
        for route in app.router.routes
        if not (
            getattr(route, "path", None) in health_paths
            and "GET" in (getattr(route, "methods", set()) or set())
        )
    ]

    def _log_env_key_prefix(source: str, env_key: str | None) -> str | None:
        """Print and return the first half (at least one char) of ``env_key``."""
        if not env_key:
            return None
        prefix = env_key[: max(1, len(env_key) // 2)]
        print(f"[{source}] expected ENVIRONMENT_API_KEY prefix: {prefix}")
        return prefix

    def _health_response(request: Request, source: str, authorized_body: dict):
        """Shared body of both health endpoints.

        Returns 503 when no environment API key is configured, a 200
        "unauthorized" payload when the request's key header does not match,
        and ``authorized_body`` otherwise.
        """
        env_key = normalize_environment_api_key()
        if not env_key:
            return JSONResponse(status_code=503, content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"})
        if not is_api_key_header_authorized(request):
            # NOTE(review, security): this returns (and logs) up to half of
            # the secret key to unauthenticated callers. Kept for
            # compatibility with existing clients; consider removing it or
            # shrinking it to a few characters.
            prefix = _log_env_key_prefix(source, env_key)
            content = {"status": "healthy", "authorized": False}
            if prefix:
                content["expected_api_key_prefix"] = prefix
            return JSONResponse(status_code=200, content=content)
        return authorized_body

    @app.get("/health")
    async def health(request: Request):
        return _health_response(request, "health", {"status": "healthy", "authorized": True})

    @app.get("/health/rollout")
    async def health_rollout(request: Request):
        return _health_response(request, "health/rollout", {"ok": True, "authorized": True})

    @app.exception_handler(RequestValidationError)
    async def _on_validation_error(request: Request, exc: RequestValidationError):
        try:
            hdr = request.headers
            # Only the presence of auth headers is recorded, never their values.
            snapshot = {
                "path": str(request.url.path),
                "have_x_api_key": bool(hdr.get("x-api-key")),
                "have_x_api_keys": bool(hdr.get("x-api-keys")),
                "have_authorization": bool(hdr.get("authorization")),
                "errors": exc.errors()[:5],
            }
            print("[422] validation", snapshot, flush=True)
        except Exception:
            # Diagnostics are best-effort; a logging failure must not replace
            # the 422 response below.
            pass
        return JSONResponse(status_code=422, content={"status": "invalid", "detail": exc.errors()[:5]})

    return app
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
if __name__ == "__main__":
    # Command-line entry point: parse options, collect .env files, start the server.
    cli = argparse.ArgumentParser(description="Run the math single-step task app locally")
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8101)
    cli.add_argument("--reload", action="store_true", help="Enable uvicorn autoreload")
    cli.add_argument(
        "--env-file",
        action="append",
        default=[],
        help="Additional .env files to load before startup",
    )
    options = cli.parse_args()

    # The default .env sits two directories above this file's directory; it is
    # listed first, followed by any --env-file values in the order given.
    bundled_env = Path(__file__).resolve().parents[2] / ".env"
    env_paths = []
    if bundled_env.exists():
        env_paths.append(str(bundled_env))
    env_paths.extend(options.env_file or [])

    run_task_app(
        build_config,
        host=options.host,
        port=options.port,
        reload=options.reload,
        env_files=env_paths,
    )
|