synth-ai 0.2.9.dev3__py3-none-any.whl → 0.2.9.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (107) hide show
  1. examples/analyze_semantic_words.sh +17 -0
  2. examples/common_old/backend.py +21 -0
  3. examples/crafter_debug_render.py +180 -0
  4. examples/evals_old/README.md +98 -0
  5. examples/evals_old/__init__.py +6 -0
  6. examples/evals_old/compare_models.py +1037 -0
  7. examples/evals_old/example_log.md +145 -0
  8. examples/evals_old/run_demo.sh +126 -0
  9. examples/evals_old/trace_analysis.py +270 -0
  10. examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
  11. examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
  12. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
  13. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
  14. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
  15. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
  16. examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
  17. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
  18. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
  19. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
  20. examples/finetuning_old/synth_qwen_v1/README.md +68 -0
  21. examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
  22. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
  23. examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
  24. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
  25. examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
  26. examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
  27. examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
  28. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
  29. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
  30. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
  31. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
  32. examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
  33. examples/finetuning_old/synth_qwen_v1/util.py +147 -0
  34. examples/rl/README.md +169 -0
  35. examples/rl/configs/eval_base_qwen.toml +15 -0
  36. examples/rl/configs/eval_rl_qwen.toml +11 -0
  37. examples/rl/configs/rl_from_base_qwen.toml +35 -0
  38. examples/rl/configs/rl_from_base_qwen17.toml +74 -0
  39. examples/rl/configs/rl_from_ft_qwen.toml +35 -0
  40. examples/rl/download_dataset.py +64 -0
  41. examples/rl/run_eval.py +435 -0
  42. examples/rl/run_rl_and_save.py +94 -0
  43. examples/rl/task_app/README.md +22 -0
  44. {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
  45. examples/rl/task_app/math_task_app.py +107 -0
  46. examples/rl_old/task_app.py +962 -0
  47. examples/run_crafter_demo.sh +10 -0
  48. examples/warming_up_to_rl/analyze_trace_db.py +420 -0
  49. examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
  50. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
  51. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
  52. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
  53. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
  54. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
  55. examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
  56. examples/warming_up_to_rl/export_trace_sft.py +541 -0
  57. examples/warming_up_to_rl/groq_test.py +88 -0
  58. examples/warming_up_to_rl/manage_secrets.py +127 -0
  59. examples/warming_up_to_rl/old/event_rewards.md +234 -0
  60. examples/warming_up_to_rl/old/notes.md +73 -0
  61. examples/warming_up_to_rl/readme.md +172 -0
  62. examples/warming_up_to_rl/run_eval.py +434 -0
  63. examples/warming_up_to_rl/run_fft_and_save.py +309 -0
  64. examples/warming_up_to_rl/run_local_rollout.py +188 -0
  65. examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
  66. examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
  67. examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
  68. examples/warming_up_to_rl/run_rl_and_save.py +101 -0
  69. examples/warming_up_to_rl/run_rollout_remote.py +129 -0
  70. examples/warming_up_to_rl/task_app/README.md +38 -0
  71. {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
  72. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
  73. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  74. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  75. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
  76. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
  77. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  78. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
  84. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  85. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
  86. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  87. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
  97. synth_ai/api/train/config_finder.py +18 -18
  98. synth_ai/api/train/env_resolver.py +28 -1
  99. synth_ai/cli/task_apps.py +291 -56
  100. synth_ai/task/apps/__init__.py +54 -13
  101. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/METADATA +1 -1
  102. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/RECORD +106 -13
  103. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/top_level.txt +1 -0
  104. synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
  105. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/WHEEL +0 -0
  106. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/entry_points.txt +0 -0
  107. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,435 @@
1
+ #!/usr/bin/env python3
2
+ """Evaluate math single-step task policies against the task app."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import asyncio
8
+ import json
9
+ import os
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+ import httpx
14
+ import tomllib
15
+
16
+
17
class TaskAppClient:
    """Minimal async client for math single-step task app.

    Usable either as an async context manager (``async with TaskAppClient(...)``)
    or directly; in the latter case the underlying ``httpx.AsyncClient`` is
    created lazily on first access to :attr:`client`.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None) -> None:
        """Store the connection settings; no network resources are opened here.

        Args:
            base_url: Root URL of the task app; trailing slashes are stripped.
            api_key: Optional key sent as the ``X-API-Key`` header on every request.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self._client: Optional[httpx.AsyncClient] = None

    def _new_client(self) -> httpx.AsyncClient:
        """Build the shared HTTP client (single place for headers/timeout settings)."""
        headers = {"X-API-Key": self.api_key} if self.api_key else {}
        return httpx.AsyncClient(
            base_url=self.base_url,
            headers=headers,
            timeout=httpx.Timeout(120.0),
            follow_redirects=True,
        )

    async def __aenter__(self) -> "TaskAppClient":
        # Reuse a client that was already created lazily instead of replacing
        # it, which would leave the earlier connection pool unclosed.
        if self._client is None:
            self._client = self._new_client()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if self._client:
            await self._client.aclose()
            self._client = None

    @property
    def client(self) -> httpx.AsyncClient:
        """Return the underlying HTTP client, creating it on first use."""
        if self._client is None:
            self._client = self._new_client()
        return self._client

    async def initialize(self, split: str, seed: int | None) -> Dict[str, Any]:
        """POST ``/env/math/initialize`` and return the decoded JSON body.

        ``seed`` is only included in the payload when it is not ``None``.
        Raises ``httpx.HTTPStatusError`` on a non-success status.
        """
        payload: Dict[str, Any] = {"config": {"split": split}}
        if seed is not None:
            payload["seed"] = seed
        resp = await self.client.post("/env/math/initialize", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def step(self, env_id: str, tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
        """POST ``/env/math/step`` with the given tool calls and return the JSON body."""
        payload = {"env_id": env_id, "action": {"tool_calls": tool_calls}}
        resp = await self.client.post("/env/math/step", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def terminate(self, env_id: str) -> None:
        """Best-effort POST to ``/env/math/terminate``.

        Cleanup must not fail an evaluation, so HTTP/transport errors are
        deliberately ignored and the response status is not checked.
        """
        try:
            await self.client.post("/env/math/terminate", json={"env_id": env_id})
        except httpx.HTTPError:
            # Intentional: termination is advisory; only httpx-level failures
            # are suppressed so that programming errors still surface.
            pass

    async def get_info(self) -> Dict[str, Any]:
        """GET ``/info`` and return the decoded JSON body."""
        resp = await self.client.get("/info")
        resp.raise_for_status()
        return resp.json()

    async def rollout(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``/rollout`` with ``payload`` and return the decoded JSON body."""
        resp = await self.client.post("/rollout", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def post_inference(
        self,
        url: str,
        payload: Dict[str, Any],
        *,
        headers: Dict[str, str] | None = None,
    ) -> Dict[str, Any]:
        """POST ``payload`` to an arbitrary ``url`` using a fresh, short-lived client.

        A separate client (60 s timeout) is used, so the task-app base URL and
        the ``X-API-Key`` default header are not applied to this request.
        """
        async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as c:
            resp = await c.post(url, json=payload, headers=headers)
            resp.raise_for_status()
            return resp.json()
87
+
88
+
89
TOOL_NAME = "math_submit"
DEFAULT_SPLIT = os.getenv("MATH_EVAL_DEFAULT_SPLIT", "validation")


def _math_tool_schema() -> List[Dict[str, Any]]:
    """Return the OpenAI-style tool list containing the single submit function."""
    # Key order is kept stable because the schema is serialized into requests.
    answer_properties = {
        "answer": {
            "type": "string",
            "description": "Final answer in simplest form",
        },
        "explanation": {
            "type": "string",
            "description": "Optional explanation of reasoning",
        },
    }
    function_spec = {
        "name": TOOL_NAME,
        "description": "Submit the final answer for the math problem.",
        "parameters": {
            "type": "object",
            "properties": answer_properties,
            "required": ["answer"],
            "additionalProperties": False,
        },
    }
    return [{"type": "function", "function": function_spec}]
119
+
120
+
121
def _build_messages(problem: str) -> List[Dict[str, Any]]:
    """Build the two-message chat prompt (system + user) for one math problem."""
    system_prompt = (
        "You solve math problems. Always respond with a single math_submit tool call "
        "containing only the final answer."
    )
    user_prompt = f"Problem:\n{problem}\nReturn the final answer via math_submit."
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
135
+
136
+
137
def _parse_tool_calls(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Extract tool calls from the first choice of a chat-completions response.

    Args:
        data: Decoded response body; only ``choices[0].message.tool_calls`` is read.

    Returns:
        One ``{"tool": <function name>, "args": <dict>}`` entry per tool call.
        ``args`` is always a dict: string arguments are JSON-decoded, and
        anything that is missing, undecodable, or not a JSON object becomes ``{}``.
    """
    choices = data.get("choices") or []
    if not choices:
        return []
    message = choices[0].get("message") or {}

    tool_calls: List[Dict[str, Any]] = []
    for call in message.get("tool_calls") or []:
        function = call.get("function") or {}
        arguments = function.get("arguments")
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except ValueError:  # includes json.JSONDecodeError
                arguments = None
        # Valid JSON that is not an object (e.g. "[1, 2]" or "3") must not leak
        # through: downstream code calls ``.get`` on the parsed arguments.
        parsed_args: Dict[str, Any] = dict(arguments) if isinstance(arguments, dict) else {}
        tool_calls.append({"tool": function.get("name"), "args": parsed_args})
    return tool_calls
160
+
161
+
162
def _detect_provider(model: str, hint: Optional[str]) -> str:
    """Pick the provider name: an explicit hint wins, else infer from the model id.

    Returns the lower-cased ``hint`` when given; otherwise ``"groq"`` for model
    ids starting with ``groq:`` (case-insensitive) and ``"generic"`` for the rest.
    """
    if hint:
        return hint.lower()
    model_id = (model or "").lower()
    return "groq" if model_id.startswith("groq:") else "generic"
169
+
170
+
171
def _resolve_inference_url(base_url: str) -> str:
    """Normalize an inference base URL into a chat-completions endpoint URL.

    Trailing slashes are dropped first. URLs already ending in
    ``/chat/completions`` are returned unchanged; URLs that end in ``/v1`` or
    contain ``/v1/`` get ``/chat/completions`` appended; everything else gets
    ``/v1/chat/completions`` appended.

    Raises:
        RuntimeError: If ``base_url`` is empty (or only slashes).
    """
    trimmed = (base_url or "").rstrip("/")
    if not trimmed:
        raise RuntimeError("inference_url cannot be empty")
    if trimmed.endswith(("/v1/chat/completions", "/chat/completions")):
        return trimmed
    if trimmed.endswith("/v1") or "/v1/" in trimmed:
        suffix = "/chat/completions"
    else:
        suffix = "/v1/chat/completions"
    return f"{trimmed}{suffix}"
184
+
185
+
186
async def _choose_actions(
    client: TaskAppClient,
    provider: str,
    model: str,
    problem: str,
    policy_cfg: Dict[str, Any],
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Ask the model for a ``math_submit`` tool call and return it with the raw body.

    For ``provider == "groq"`` the request goes through the task app's Groq
    proxy route; otherwise it is posted to ``policy_cfg["inference_url"]``
    (normalized to a chat-completions endpoint).

    Returns:
        ``(tool_calls, body)`` where ``tool_calls`` is the parsed list and
        ``body`` is the decoded response (or ``{"raw": <text prefix>}`` when the
        non-groq response is not JSON).

    Raises:
        RuntimeError: Missing ``inference_url``, read timeout, or a 4xx/5xx
            status on the non-groq path.
        httpx.HTTPStatusError: Non-success status on the groq proxy path.
    """
    chat_messages = _build_messages(problem)
    request_body: Dict[str, Any] = {
        "model": model,
        "messages": chat_messages,
        "tools": _math_tool_schema(),
        "tool_choice": {"type": "function", "function": {"name": TOOL_NAME}},
        "temperature": policy_cfg.get("temperature", 0.0),
        "top_p": policy_cfg.get("top_p", 1.0),
        "max_tokens": policy_cfg.get("max_tokens", 256),
    }

    if provider == "groq":
        # Task app proxies Groq requests; reuse existing headers on the client
        proxied = await client.client.post("/proxy/groq/v1/chat/completions", json=request_body)
        proxied.raise_for_status()
        body = proxied.json()
        return _parse_tool_calls(body), body

    inference_url = policy_cfg.get("inference_url")
    if not inference_url:
        raise RuntimeError("inference_url required for non-groq evaluations")

    # ``headers`` takes precedence; ``extra_headers`` only fills in missing keys.
    request_headers = dict(policy_cfg.get("headers") or {})
    for header_name, header_value in (policy_cfg.get("extra_headers") or {}).items():
        request_headers.setdefault(header_name, header_value)

    target_url = _resolve_inference_url(inference_url)
    # NOTE(review): this reuses the task-app client, so its default headers
    # (including X-API-Key when configured) are also sent to the inference
    # endpoint -- confirm that is intended.
    try:
        response = await client.client.post(
            target_url,
            json=request_body,
            headers=request_headers or None,
        )
    except httpx.ReadTimeout as exc:
        raise RuntimeError(
            "Inference request timed out. Check the inference service." ) from exc

    try:
        body = response.json()
    except Exception:
        # Non-JSON body: keep a bounded snippet for the error message below.
        body = {"raw": response.text[:800]}

    status_code = response.status_code
    if status_code >= 500:
        raise RuntimeError(f"Inference server error {status_code}: {body}")
    if status_code >= 400:
        raise RuntimeError(f"Inference request invalid ({status_code}): {body}")
    return _parse_tool_calls(body), body
240
+
241
+
242
def _tool_to_answer(tool_calls: List[Dict[str, Any]]) -> str:
    """Return the stripped ``answer`` argument of the first tool call, or ``""``.

    A missing or falsy ``answer`` also yields the empty string.
    """
    if not tool_calls:
        return ""
    first_args = tool_calls[0].get("args") or {}
    raw_answer = first_args.get("answer") or ""
    return str(raw_answer).strip()
248
+
249
+
250
async def eval_episode(
    client: TaskAppClient,
    *,
    split: str,
    seed: Optional[int],
    model: str,
    provider: str,
    policy_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    """Run one client-driven episode: initialize, query the model, step, terminate.

    Args:
        client: Task app client used for the environment calls.
        split: Dataset split passed to the environment.
        seed: Optional seed passed to the environment.
        model: Model identifier sent in the inference request.
        provider: ``"groq"`` or another provider name (see ``_choose_actions``).
        policy_cfg: Sampling/endpoint settings for the inference request.

    Returns:
        A result dict with ``seed``, ``split``, ``problem``, ``answer``,
        ``expected``, ``reward``, ``status``, ``correct``, ``raw_response``
        and ``tool_calls``.

    Raises:
        Whatever inference or the environment step raises; the environment is
        still terminated in that case.
    """
    created = await client.initialize(split, seed)
    env_id = created["env_id"]
    observation = created.get("observation") or {}
    problem = observation.get("problem") or ""

    # Once the environment exists it must be released even if inference or the
    # step fails; otherwise failed episodes leave environments behind.
    try:
        tool_calls, raw_response = await _choose_actions(client, provider, model, problem, policy_cfg)
        answer = _tool_to_answer(tool_calls)
        result = await client.step(env_id, tool_calls)
    finally:
        await client.terminate(env_id)

    info = result.get("info") or {}
    reward = result.get("reward") or 0.0
    # Fall back to a reward-derived status when the environment reports none.
    status = info.get("status") or ("correct" if reward > 0 else "incorrect")
    return {
        "seed": seed,
        "split": split,
        "problem": problem,
        "answer": answer,
        "expected": info.get("expected_answer"),
        "reward": reward,
        "status": status,
        "correct": bool(info.get("correct")),
        "raw_response": raw_response,
        "tool_calls": tool_calls,
    }
284
+
285
+
286
async def eval_via_rollout(
    client: TaskAppClient,
    *,
    run_id: str,
    split: str,
    seed: Optional[int],
    model: str,
    policy_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    """Run one episode through the task app's server-side ``/rollout`` endpoint.

    Args:
        client: Task app client; only ``client.rollout`` is used.
        run_id: Identifier sent with the rollout request.
        split: Dataset split for the environment config.
        seed: Optional environment seed.
        model: Unused here (callers pass the model inside ``policy_cfg``); kept
            for signature parity with ``eval_episode``.
        policy_cfg: Policy configuration forwarded verbatim.

    Returns:
        ``{"reward": 0.0, "correct": False, "status": "missing"}`` when the
        response has no trajectories; otherwise a result dict built from the
        first step of the first trajectory.
    """
    payload = {
        "run_id": run_id,
        "env": {
            "env_name": "math",
            "config": {"split": split},
            "seed": seed,
        },
        "policy": {
            "policy_name": "math-single-step",
            "config": policy_cfg,
        },
        "ops": ["agent", "env"],
        "on_done": "terminate",
    }
    resp = await client.rollout(payload)
    trajs = resp.get("trajectories") or []
    if not trajs:
        return {"reward": 0.0, "correct": False, "status": "missing"}
    traj = trajs[0]
    steps = traj.get("steps") or []
    step = steps[0] if steps else {}
    info = step.get("info") or {}
    # The original referenced an undefined ``observation`` name here, which
    # raised NameError for every non-empty rollout.
    # NOTE(review): assumes the step carries its observation under "obs"
    # (falling back to "observation") -- confirm against the rollout contract.
    observation = step.get("obs") or step.get("observation") or {}
    problem = observation.get("problem") if isinstance(observation, dict) else None
    step_tool_calls = step.get("tool_calls") or []
    return {
        "seed": seed,
        "split": split,
        "problem": problem,
        "answer": _tool_to_answer(step_tool_calls),
        "expected": info.get("expected_answer"),
        "reward": step.get("reward") or 0.0,
        "status": info.get("status"),
        "correct": bool(info.get("correct")),
        "raw_response": resp,
        "tool_calls": step_tool_calls,
    }
329
+
330
+
331
+ def _load_config(path: Optional[str]) -> Dict[str, Any]:
332
+ if not path:
333
+ return {}
334
+ with open(path, "rb") as fh:
335
+ return tomllib.load(fh)
336
+
337
+
338
def _default_policy_cfg(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Derive the policy configuration from the loaded TOML config and environment.

    Starts from a copy of ``cfg["policy"]``, then:

    * fills ``inference_url`` from the ``INFERENCE_URL`` env var if absent;
    * copies ``max_tokens``, ``temperature``, ``top_p``, ``headers`` and
      ``extra_headers`` from the top level of ``cfg`` when the policy section
      does not define them;
    * adds ``Authorization: Bearer $SYNTH_API_KEY`` to ``extra_headers`` when
      ``SYNTH_API_KEY`` is set and no authorization header is configured.

    Returns:
        A new dict; ``cfg`` itself is not modified.
    """
    policy = dict(cfg.get("policy") or {})
    if "inference_url" not in policy:
        env_url = os.getenv("INFERENCE_URL")
        if env_url:
            policy["inference_url"] = env_url
    for key in ("max_tokens", "temperature", "top_p", "headers", "extra_headers"):
        if key not in policy and key in cfg:
            policy[key] = cfg[key]

    extra_headers = dict(policy.get("extra_headers") or {})
    headers = dict(policy.get("headers") or {})
    # HTTP header names are case-insensitive: a configured "authorization"
    # must suppress the default, otherwise two conflicting credentials would
    # be sent under differently-cased names.
    has_authorization = any(
        str(name).lower() == "authorization" for name in (*headers, *extra_headers)
    )
    if not has_authorization:
        synth_key = os.getenv("SYNTH_API_KEY")
        if synth_key:
            extra_headers["Authorization"] = f"Bearer {synth_key}"
    if extra_headers:
        policy["extra_headers"] = extra_headers
    return policy
356
+
357
+
358
async def main() -> None:
    """Evaluate a model on the math task app and print per-episode results.

    Settings come from the optional ``--toml`` file with environment-variable
    fallbacks (``TASK_APP_URL``, ``EVAL_MODEL``, ``EVAL_SPLIT``,
    ``NUM_EPISODES``, ``ENVIRONMENT_API_KEY``). With ``--use-rollout`` each
    episode is run server-side via ``/rollout``; otherwise the episode is
    driven from this process.

    Raises:
        RuntimeError: If no task app URL is configured.
    """
    parser = argparse.ArgumentParser(description="Evaluate math task app policies")
    parser.add_argument("--toml", help="Path to TOML config", default=None)
    parser.add_argument("--use-rollout", action="store_true", help="Use server-side rollout")
    args = parser.parse_args()

    cfg = _load_config(args.toml)
    task_app_url = (cfg.get("task_app_url") or os.getenv("TASK_APP_URL") or "").rstrip("/")
    if not task_app_url:
        raise RuntimeError("task_app_url missing; set in TOML or export TASK_APP_URL")
    model = cfg.get("model") or os.getenv("EVAL_MODEL") or "groq:qwen-2.5-7b"
    split = cfg.get("split") or os.getenv("EVAL_SPLIT") or DEFAULT_SPLIT
    episodes = int(cfg.get("num_episodes") or os.getenv("NUM_EPISODES") or 50)
    seed_start = int(cfg.get("seed_start") or 0)

    policy_cfg = _default_policy_cfg(cfg)
    # ``or {}`` guards against a falsy, non-dict ``policy`` value in the TOML.
    provider_hint = (
        cfg.get("provider")
        or (cfg.get("policy") or {}).get("provider")
        or policy_cfg.get("provider")
    )
    provider = _detect_provider(model, provider_hint)
    # The provider is a routing hint for this script, not a policy setting.
    policy_cfg.pop("provider", None)

    api_key = os.getenv("ENVIRONMENT_API_KEY")

    successes = 0
    failures: Dict[str, int] = {}
    results: List[Dict[str, Any]] = []

    async with TaskAppClient(task_app_url, api_key=api_key) as client:
        for episode in range(episodes):
            seed = seed_start + episode
            if args.use_rollout:
                data = await eval_via_rollout(
                    client,
                    run_id=f"eval-{seed}",
                    split=split,
                    seed=seed,
                    model=model,
                    policy_cfg={"model": model, **policy_cfg},
                )
            else:
                data = await eval_episode(
                    client,
                    split=split,
                    seed=seed,
                    model=model,
                    provider=provider,
                    policy_cfg={"model": model, **policy_cfg},
                )
            results.append(data)
            status = data.get("status") or "unknown"
            if data.get("correct"):
                successes += 1
            else:
                # Only failed episodes are tallied, so the breakdown below
                # never lists a status with a zero count (e.g. "correct: 0").
                failures[status] = failures.get(status, 0) + 1
            answer = data.get("answer")
            expected = data.get("expected")
            problem = data.get("problem")
            tool_calls = data.get("tool_calls") or []
            print(
                f"Episode {episode+1}/{episodes} seed={seed} status={status} reward={data.get('reward')}\n"
                f"  problem: {problem!r}\n"
                f"  tool   : {tool_calls!r}\n"
                f"  answer : {answer!r}\n  expected: {expected!r}",
                flush=True,
            )

    # max(..., 1) avoids a division by zero when no episodes were requested.
    accuracy = successes / max(episodes, 1)
    print("=== Evaluation Summary ===")
    print(f"Task App: {task_app_url}")
    print(f"Model: {model}")
    print(f"Split: {split}")
    print(f"Episodes: {episodes}")
    print(f"Accuracy: {accuracy:.3f}")
    print("Failure breakdown:")
    # Most frequent failure first; ties broken alphabetically by status.
    for status, count in sorted(failures.items(), key=lambda kv: (-kv[1], kv[0])):
        print(f"  {status}: {count}")
432
+
433
+
434
+ if __name__ == "__main__":
435
+ asyncio.run(main())
@@ -0,0 +1,94 @@
1
+ #!/usr/bin/env python3
2
+ """Submit math RL training jobs via Synth backend."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ import sys
10
+ from pathlib import Path
11
+ from typing import Any, Dict
12
+
13
+ import requests
14
+ import tomllib
15
+
16
+
17
+ def _load_toml(path: Path) -> Dict[str, Any]:
18
+ if not path.exists():
19
+ print(f"config not found: {path}", file=sys.stderr)
20
+ sys.exit(2)
21
+ with path.open("rb") as fh:
22
+ return tomllib.load(fh)
23
+
24
+
25
def main() -> None:
    """Create a math RL job by POSTing the TOML config to ``<backend>/rl/jobs``.

    Exit status: 0 when the backend answers 200/201, 1 on any other response
    or when the request itself fails, 2 on missing/invalid local configuration
    (config file, task URL, model section, or ``SYNTH_API_KEY``).
    """
    parser = argparse.ArgumentParser(description="Create math RL job via backend RL endpoint")
    parser.add_argument("--backend", default=os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api"))
    parser.add_argument("--config", required=True, help="Path to RL TOML config")
    parser.add_argument("--task-url", default=os.getenv("TASK_APP_URL", ""), help="Override task service URL")
    parser.add_argument("--idempotency", default=os.getenv("RL_IDEMPOTENCY_KEY", ""), help="Optional Idempotency-Key header")
    args = parser.parse_args()

    cfg_path = Path(args.config).expanduser()
    cfg = _load_toml(cfg_path)

    services = cfg.get("services") if isinstance(cfg.get("services"), dict) else {}

    # Precedence: --task-url (which itself defaults to TASK_APP_URL), then the
    # environment, then [services].task_url from the TOML.
    task_url = (args.task_url or "").strip() or (os.getenv("TASK_APP_URL") or "").strip() or (services.get("task_url") or "").strip()
    if not task_url:
        print("Missing task service URL. Provide --task-url or set TASK_APP_URL or services.task_url in TOML", file=sys.stderr)
        sys.exit(2)

    model_cfg = cfg.get("model") if isinstance(cfg.get("model"), dict) else {}
    has_source = bool((model_cfg.get("source") or "").strip())
    has_base = bool((model_cfg.get("base") or "").strip())
    # Equal flags means either both or neither were given; exactly one is required.
    if has_source == has_base:
        print("Model section must specify exactly one of [model].source or [model].base", file=sys.stderr)
        sys.exit(2)

    payload: Dict[str, Any] = {
        "job_type": "rl",
        "compute": cfg.get("compute", {}),
        "data": {
            "endpoint_base_url": task_url,
            "config": cfg,
        },
        "tags": cfg.get("tags", {}),
    }

    backend = str(args.backend).rstrip("/")
    url = f"{backend}/rl/jobs"
    api_key = (os.getenv("SYNTH_API_KEY") or os.getenv("synth_key") or "").strip()
    if not api_key:
        print("Missing SYNTH_API_KEY in env", file=sys.stderr)
        sys.exit(2)

    headers = {
        "content-type": "application/json",
        "authorization": f"Bearer {api_key}",
    }
    idem = (args.idempotency or "").strip()
    if idem:
        headers["Idempotency-Key"] = idem

    print(f"[INFO] POST {url}")
    # Only key names are previewed; the full config is not echoed to stdout.
    preview = {"job_type": payload["job_type"], "data": {"config_keys": list(cfg.keys())}}
    print(f"[INFO] Payload preview: {json.dumps(preview)}")

    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=120)
    except requests.RequestException as exc:
        # Connection errors/timeouts are reported like every other failure in
        # this script (stderr + exit code) instead of an unhandled traceback.
        print(f"[ERROR] Request to {url} failed: {exc}", file=sys.stderr)
        sys.exit(1)

    ok = resp.status_code in (200, 201)
    try:
        snippet = resp.json()
    except Exception:
        # Non-JSON body: show a bounded prefix of the raw text instead.
        snippet = resp.text[:300]
    print(f"[INFO] Response: {resp.status_code} {snippet}")
    sys.exit(0 if ok else 1)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ main()
94
+
@@ -0,0 +1,22 @@
1
+ # Math Single-Step Task App
2
+
3
+ This directory hosts the legacy entrypoint for the math single-step task app. Prefer starting the app via:
4
+
5
+ ```bash
6
+ uvx synth-ai serve math-single-step --env-file examples/rl/.env --port 8101
7
+ ```
8
+
9
+ If you need to run it directly (e.g., for Modal `modal deploy` compatibility), use:
10
+
11
+ ```bash
12
+ python examples/rl/task_app/math_task_app.py --env-file examples/rl/.env --port 8101
13
+ ```
14
+
15
+ Environment variables:
16
+
17
+ - `MATH_DATASET_NAME` – defaults to `EleutherAI/math`
18
+ - `MATH_DATASET_CONFIG` – defaults to `algebra__linear_1d`
19
+ - `MATH_DATASET_DEFAULT_SPLIT`, `MATH_DATASET_VALIDATION_SPLIT`, `MATH_DATASET_TEST_SPLIT`
20
+
21
+ The task app enforces a single `math_submit` tool call per episode, enabling RL to reward correct final answers and penalise missing or malformed submissions.
22
+
@@ -16,7 +16,7 @@ from datasets import load_dataset
16
16
  from fastapi import APIRouter, HTTPException, Request
17
17
  from pydantic import BaseModel, Field
18
18
 
19
- from ..contracts import (
19
+ from synth_ai.task.contracts import (
20
20
  RolloutMetrics,
21
21
  RolloutRequest,
22
22
  RolloutResponse,
@@ -24,18 +24,18 @@ from ..contracts import (
24
24
  RolloutTrajectory,
25
25
  TaskInfo,
26
26
  )
27
- from ..datasets import TaskDatasetRegistry, TaskDatasetSpec
28
- from ..rubrics import Rubric, load_rubric
29
- from ..server import ProxyConfig, RubricBundle, TaskAppConfig
30
- from ..errors import http_exception
31
- from ..tracing_utils import (
27
+ from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
28
+ from synth_ai.task.rubrics import Rubric, load_rubric
29
+ from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
30
+ from synth_ai.task.errors import http_exception
31
+ from synth_ai.task.tracing_utils import (
32
32
  build_tracer_factory,
33
33
  resolve_sft_output_dir,
34
34
  resolve_tracing_db_url,
35
35
  tracing_env_enabled,
36
36
  )
37
- from ..vendors import normalize_vendor_keys
38
- from . import ModalDeploymentConfig, TaskAppEntry, register_task_app
37
+ from synth_ai.task.vendors import normalize_vendor_keys
38
+ from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
39
39
  from synth_ai.tracing_v3.session_tracer import SessionTracer
40
40
 
41
41
  REPO_ROOT = Path(__file__).resolve().parents[3]
@@ -0,0 +1,107 @@
1
+ """Legacy entrypoint for the math single-step task app."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+
8
+ from fastapi.exceptions import RequestValidationError
9
+ from fastapi.responses import JSONResponse
10
+ from starlette.requests import Request
11
+
12
+ from synth_ai.task.server import create_task_app, run_task_app
13
+ from .math_single_step import build_config
14
+ from synth_ai.task.auth import is_api_key_header_authorized, normalize_environment_api_key
15
+
16
+
17
def fastapi_app():
    """Return a FastAPI application for hosting the math task app.

    The app produced by ``create_task_app`` has its GET ``/health`` and
    ``/health/rollout`` routes replaced by handlers that answer 200 even when
    the caller's API key header is not authorized, and a handler that logs a
    short snapshot for request-validation (422) errors is installed.
    """

    app = create_task_app(build_config())

    health_paths = {"/health", "/health/rollout"}

    def _is_default_health_route(route) -> bool:
        route_path = getattr(route, "path", None)
        route_methods = getattr(route, "methods", set()) or set()
        return route_path in health_paths and "GET" in route_methods

    # Replace default health endpoints with auth-tolerant handlers.
    app.router.routes = [
        route for route in app.router.routes if not _is_default_health_route(route)
    ]

    def _log_env_key_prefix(source: str, env_key: str | None) -> str | None:
        if not env_key:
            return None
        # NOTE(review): this exposes the first half of ENVIRONMENT_API_KEY in
        # logs and (via the health handlers) to unauthorized callers --
        # confirm this diagnostic is acceptable outside local debugging.
        half_length = max(1, len(env_key) // 2)
        prefix = env_key[:half_length]
        print(f"[{source}] expected ENVIRONMENT_API_KEY prefix: {prefix}")
        return prefix

    def _health_early_response(request: Request, source: str) -> JSONResponse | None:
        """Return the 503/unauthorized response shared by both health routes, else None."""
        env_key = normalize_environment_api_key()
        if not env_key:
            return JSONResponse(status_code=503, content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"})
        if is_api_key_header_authorized(request):
            return None
        prefix = _log_env_key_prefix(source, env_key)
        content = {"status": "healthy", "authorized": False}
        if prefix:
            content["expected_api_key_prefix"] = prefix
        return JSONResponse(status_code=200, content=content)

    @app.get("/health")
    async def health(request: Request):
        early = _health_early_response(request, "health")
        if early is not None:
            return early
        return {"status": "healthy", "authorized": True}

    @app.get("/health/rollout")
    async def health_rollout(request: Request):
        early = _health_early_response(request, "health/rollout")
        if early is not None:
            return early
        return {"ok": True, "authorized": True}

    @app.exception_handler(RequestValidationError)
    async def _on_validation_error(request: Request, exc: RequestValidationError):
        # Logging is best-effort; it must never mask the 422 response itself.
        try:
            request_headers = request.headers
            snapshot = {
                "path": str(request.url.path),
                "have_x_api_key": bool(request_headers.get("x-api-key")),
                "have_x_api_keys": bool(request_headers.get("x-api-keys")),
                "have_authorization": bool(request_headers.get("authorization")),
                "errors": exc.errors()[:5],
            }
            print("[422] validation", snapshot, flush=True)
        except Exception:
            pass
        return JSONResponse(status_code=422, content={"status": "invalid", "detail": exc.errors()[:5]})

    return app
82
+
83
+
84
if __name__ == "__main__":
    # Command-line entrypoint: parse options, collect .env files, start the app.
    cli = argparse.ArgumentParser(description="Run the math single-step task app locally")
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8101)
    cli.add_argument("--reload", action="store_true", help="Enable uvicorn autoreload")
    cli.add_argument(
        "--env-file",
        action="append",
        default=[],
        help="Additional .env files to load before startup",
    )
    options = cli.parse_args()

    # The default .env (two directories above this file's directory) is listed
    # first, followed by any --env-file values in the order given.
    # NOTE(review): the README next to this file uses examples/rl/.env, which
    # would be parents[1] -- confirm which location is intended.
    implicit_env = Path(__file__).resolve().parents[2] / ".env"
    env_file_list = [str(implicit_env)] if implicit_env.exists() else []
    env_file_list += list(options.env_file or [])

    run_task_app(
        build_config,
        host=options.host,
        port=options.port,
        reload=options.reload,
        env_files=env_file_list,
    )