synth-ai 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. See the release notes for more details.
- examples/README.md +1 -0
- examples/multi_step/SFT_README.md +147 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +9 -9
- examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
- examples/multi_step/convert_traces_to_sft.py +84 -0
- examples/multi_step/run_sft_qwen30b.sh +45 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +2 -1
- examples/qwen_coder/configs/coder_lora_4b.toml +2 -1
- examples/qwen_coder/configs/coder_lora_small.toml +2 -1
- examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
- examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
- examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
- examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
- examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
- examples/qwen_vl/QUICKSTART.md +327 -0
- examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
- examples/qwen_vl/README.md +154 -0
- examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
- examples/qwen_vl/RL_VISION_TESTING.md +333 -0
- examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
- examples/qwen_vl/SETUP_COMPLETE.md +275 -0
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +490 -0
- examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
- examples/qwen_vl/__init__.py +2 -0
- examples/qwen_vl/collect_data_via_cli.md +423 -0
- examples/qwen_vl/collect_vision_traces.py +368 -0
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +127 -0
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +60 -0
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +43 -0
- examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +45 -0
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +44 -0
- examples/qwen_vl/configs/filter_qwen2vl_sft.toml +50 -0
- examples/qwen_vl/configs/filter_vision_sft.toml +53 -0
- examples/qwen_vl/configs/filter_vision_test.toml +8 -0
- examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
- examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
- examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
- examples/qwen_vl/run_vision_comparison.sh +62 -0
- examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
- examples/qwen_vl/test_image_validation.py +201 -0
- examples/qwen_vl/test_sft_vision_data.py +110 -0
- examples/rl/README.md +1 -1
- examples/rl/configs/eval_base_qwen.toml +17 -0
- examples/rl/configs/eval_rl_qwen.toml +13 -0
- examples/rl/configs/rl_from_base_qwen.toml +37 -0
- examples/rl/configs/rl_from_base_qwen17.toml +76 -0
- examples/rl/configs/rl_from_ft_qwen.toml +37 -0
- examples/rl/run_eval.py +436 -0
- examples/rl/run_rl_and_save.py +111 -0
- examples/rl/task_app/README.md +22 -0
- examples/rl/task_app/math_single_step.py +990 -0
- examples/rl/task_app/math_task_app.py +111 -0
- examples/sft/README.md +5 -5
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -2
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -3
- examples/sft/evaluate.py +2 -4
- examples/sft/export_dataset.py +7 -4
- examples/swe/task_app/README.md +1 -1
- examples/swe/task_app/grpo_swe_mini.py +0 -1
- examples/swe/task_app/grpo_swe_mini_task_app.py +0 -12
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +13 -13
- examples/swe/task_app/hosted/policy_routes.py +0 -2
- examples/swe/task_app/hosted/rollout.py +0 -8
- examples/task_apps/crafter/task_app/grpo_crafter.py +4 -7
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +59 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +30 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +62 -31
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +16 -14
- examples/task_apps/enron/__init__.py +1 -0
- examples/vlm/README.md +3 -3
- examples/vlm/configs/crafter_vlm_gpt4o.toml +2 -0
- examples/vlm/crafter_openai_vlm_agent.py +3 -5
- examples/vlm/filter_image_rows.py +1 -1
- examples/vlm/run_crafter_vlm_benchmark.py +2 -2
- examples/warming_up_to_rl/_utils.py +92 -0
- examples/warming_up_to_rl/analyze_trace_db.py +1 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +2 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_ft.toml +2 -0
- examples/warming_up_to_rl/export_trace_sft.py +174 -60
- examples/warming_up_to_rl/readme.md +63 -132
- examples/warming_up_to_rl/run_fft_and_save.py +1 -1
- examples/warming_up_to_rl/run_rl_and_save.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +42 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +696 -0
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +478 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +204 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +618 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1081 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1861 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +62 -0
- synth_ai/__init__.py +44 -30
- synth_ai/_utils/__init__.py +47 -0
- synth_ai/_utils/base_url.py +10 -0
- synth_ai/_utils/http.py +10 -0
- synth_ai/_utils/prompts.py +10 -0
- synth_ai/_utils/task_app_state.py +12 -0
- synth_ai/_utils/user_config.py +10 -0
- synth_ai/api/models/supported.py +144 -7
- synth_ai/api/train/__init__.py +13 -1
- synth_ai/api/train/cli.py +30 -7
- synth_ai/api/train/config_finder.py +18 -11
- synth_ai/api/train/env_resolver.py +13 -10
- synth_ai/cli/__init__.py +62 -78
- synth_ai/cli/_modal_wrapper.py +7 -5
- synth_ai/cli/_typer_patch.py +0 -2
- synth_ai/cli/_validate_task_app.py +22 -4
- synth_ai/cli/legacy_root_backup.py +3 -1
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/task_app_discovery.py +7 -0
- synth_ai/cli/lib/task_app_env.py +518 -0
- synth_ai/cli/recent.py +2 -1
- synth_ai/cli/setup.py +266 -0
- synth_ai/cli/status.py +1 -1
- synth_ai/cli/task_app_deploy.py +16 -0
- synth_ai/cli/task_app_list.py +25 -0
- synth_ai/cli/task_app_modal_serve.py +16 -0
- synth_ai/cli/task_app_serve.py +18 -0
- synth_ai/cli/task_apps.py +71 -31
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/train.py +18 -0
- synth_ai/cli/tui.py +7 -2
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/__init__.py +10 -0
- synth_ai/demos/core/__init__.py +28 -1
- synth_ai/demos/crafter/__init__.py +1 -0
- synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/demos/demo_registry.py +176 -0
- synth_ai/demos/math/__init__.py +1 -0
- synth_ai/demos/math/_common.py +16 -0
- synth_ai/demos/math/app.py +38 -0
- synth_ai/demos/math/config.toml +76 -0
- synth_ai/demos/math/deploy_modal.py +54 -0
- synth_ai/demos/math/modal_task_app.py +702 -0
- synth_ai/demos/math/task_app_entry.py +51 -0
- synth_ai/environments/environment/core.py +7 -1
- synth_ai/environments/examples/bandit/engine.py +0 -1
- synth_ai/environments/examples/bandit/environment.py +0 -1
- synth_ai/environments/examples/wordle/environment.py +0 -1
- synth_ai/evals/base.py +16 -5
- synth_ai/evals/client.py +1 -1
- synth_ai/inference/client.py +1 -1
- synth_ai/judge_schemas.py +8 -8
- synth_ai/learning/client.py +1 -1
- synth_ai/learning/health.py +1 -1
- synth_ai/learning/jobs.py +1 -1
- synth_ai/learning/rl/client.py +1 -1
- synth_ai/learning/rl/env_keys.py +1 -1
- synth_ai/learning/rl/secrets.py +1 -1
- synth_ai/learning/sft/client.py +1 -1
- synth_ai/learning/sft/data.py +407 -4
- synth_ai/learning/validators.py +4 -1
- synth_ai/task/apps/__init__.py +4 -2
- synth_ai/task/config.py +6 -4
- synth_ai/task/rubrics/__init__.py +1 -2
- synth_ai/task/rubrics/loaders.py +14 -10
- synth_ai/task/rubrics.py +219 -0
- synth_ai/task/trace_correlation_helpers.py +24 -11
- synth_ai/task/tracing_utils.py +14 -3
- synth_ai/task/validators.py +2 -3
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/config.py +15 -13
- synth_ai/tracing_v3/constants.py +21 -0
- synth_ai/tracing_v3/db_config.py +3 -1
- synth_ai/tracing_v3/decorators.py +10 -7
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/session_tracer.py +7 -7
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +8 -9
- synth_ai/tracing_v3/turso/native_manager.py +80 -72
- synth_ai/tracing_v3/utils.py +2 -2
- synth_ai/tui/cli/query_experiments.py +4 -4
- synth_ai/tui/cli/query_experiments_v3.py +4 -4
- synth_ai/tui/dashboard.py +14 -9
- synth_ai/utils/__init__.py +101 -0
- synth_ai/utils/base_url.py +94 -0
- synth_ai/utils/cli.py +131 -0
- synth_ai/utils/env.py +287 -0
- synth_ai/utils/http.py +169 -0
- synth_ai/utils/modal.py +308 -0
- synth_ai/utils/process.py +212 -0
- synth_ai/utils/prompts.py +39 -0
- synth_ai/utils/sqld.py +122 -0
- synth_ai/utils/task_app_discovery.py +882 -0
- synth_ai/utils/task_app_env.py +186 -0
- synth_ai/utils/task_app_state.py +318 -0
- synth_ai/utils/user_config.py +137 -0
- synth_ai/v0/config/__init__.py +1 -5
- synth_ai/v0/config/base_url.py +1 -7
- synth_ai/v0/tracing/config.py +1 -1
- synth_ai/v0/tracing/decorators.py +1 -1
- synth_ai/v0/tracing/upload.py +1 -1
- synth_ai/v0/tracing_v1/config.py +1 -1
- synth_ai/v0/tracing_v1/decorators.py +1 -1
- synth_ai/v0/tracing_v1/upload.py +1 -1
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/METADATA +85 -31
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/RECORD +229 -117
- synth_ai/cli/man.py +0 -106
- synth_ai/compound/cais.py +0 -0
- synth_ai/core/experiment.py +0 -13
- synth_ai/core/system.py +0 -15
- synth_ai/demo_registry.py +0 -295
- synth_ai/handshake.py +0 -109
- synth_ai/http.py +0 -26
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/top_level.txt +0 -0
examples/rl/run_eval.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Evaluate math single-step task policies against the task app."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import asyncio
|
|
8
|
+
import contextlib
|
|
9
|
+
import json
|
|
10
|
+
import os
|
|
11
|
+
import tomllib
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import httpx
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TaskAppClient:
    """Minimal async client for the math single-step task app.

    Wraps an ``httpx.AsyncClient`` bound to the task app base URL and exposes
    the handful of endpoints the evaluation loop needs. Usable as an async
    context manager; the underlying client is also created lazily on first
    access so ad-hoc use without ``async with`` still works.
    """

    def __init__(self, base_url: str, api_key: str | None = None) -> None:
        # Normalise so the endpoint paths below (always "/...") join cleanly.
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self._client: httpx.AsyncClient | None = None

    def _build_client(self) -> httpx.AsyncClient:
        """Single source of truth for AsyncClient construction.

        Previously duplicated in ``__aenter__`` and the ``client`` property,
        which risked the two copies drifting apart.
        """
        headers = {"X-API-Key": self.api_key} if self.api_key else {}
        return httpx.AsyncClient(
            base_url=self.base_url,
            headers=headers,
            timeout=httpx.Timeout(120.0),
            follow_redirects=True,
        )

    async def __aenter__(self) -> TaskAppClient:
        self._client = self._build_client()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if self._client:
            await self._client.aclose()
            self._client = None

    @property
    def client(self) -> httpx.AsyncClient:
        """Shared AsyncClient, created lazily outside of ``async with`` use."""
        if self._client is None:
            self._client = self._build_client()
        return self._client

    async def initialize(self, split: str, seed: int | None) -> dict[str, Any]:
        """Create a new math environment; returns the initialize response JSON."""
        payload: dict[str, Any] = {"config": {"split": split}}
        if seed is not None:
            payload["seed"] = seed
        resp = await self.client.post("/env/math/initialize", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def step(self, env_id: str, tool_calls: list[dict[str, Any]]) -> dict[str, Any]:
        """Submit *tool_calls* as one action for *env_id*; returns step JSON."""
        payload = {"env_id": env_id, "action": {"tool_calls": tool_calls}}
        resp = await self.client.post("/env/math/step", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def terminate(self, env_id: str) -> None:
        """Best-effort environment teardown; errors are deliberately ignored."""
        with contextlib.suppress(Exception):
            await self.client.post("/env/math/terminate", json={"env_id": env_id})

    async def get_info(self) -> dict[str, Any]:
        """Fetch the task app's /info metadata."""
        resp = await self.client.get("/info")
        resp.raise_for_status()
        return resp.json()

    async def rollout(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Run a server-side rollout with the given payload."""
        resp = await self.client.post("/rollout", json=payload)
        resp.raise_for_status()
        return resp.json()

    async def post_inference(
        self,
        url: str,
        payload: dict[str, Any],
        *,
        headers: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """POST *payload* to an arbitrary inference *url*.

        Uses a fresh short-lived client (60s timeout, no base_url) because the
        inference endpoint may live on a different host than the task app.
        """
        async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as c:
            resp = await c.post(url, json=payload, headers=headers)
            resp.raise_for_status()
            return resp.json()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
TOOL_NAME = "math_submit"
|
|
94
|
+
DEFAULT_SPLIT = os.getenv("MATH_EVAL_DEFAULT_SPLIT", "validation")
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _math_tool_schema() -> list[dict[str, Any]]:
|
|
98
|
+
return [
|
|
99
|
+
{
|
|
100
|
+
"type": "function",
|
|
101
|
+
"function": {
|
|
102
|
+
"name": TOOL_NAME,
|
|
103
|
+
"description": "Submit the final answer for the math problem.",
|
|
104
|
+
"parameters": {
|
|
105
|
+
"type": "object",
|
|
106
|
+
"properties": {
|
|
107
|
+
"answer": {
|
|
108
|
+
"type": "string",
|
|
109
|
+
"description": "Final answer in simplest form",
|
|
110
|
+
},
|
|
111
|
+
"explanation": {
|
|
112
|
+
"type": "string",
|
|
113
|
+
"description": "Optional explanation of reasoning",
|
|
114
|
+
},
|
|
115
|
+
},
|
|
116
|
+
"required": ["answer"],
|
|
117
|
+
"additionalProperties": False,
|
|
118
|
+
},
|
|
119
|
+
},
|
|
120
|
+
}
|
|
121
|
+
]
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _build_messages(problem: str) -> list[dict[str, Any]]:
|
|
125
|
+
return [
|
|
126
|
+
{
|
|
127
|
+
"role": "system",
|
|
128
|
+
"content": (
|
|
129
|
+
"You solve math problems. Always respond with a single math_submit tool call "
|
|
130
|
+
"containing only the final answer."
|
|
131
|
+
),
|
|
132
|
+
},
|
|
133
|
+
{
|
|
134
|
+
"role": "user",
|
|
135
|
+
"content": f"Problem:\n{problem}\nReturn the final answer via math_submit.",
|
|
136
|
+
},
|
|
137
|
+
]
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _parse_tool_calls(data: dict[str, Any]) -> list[dict[str, Any]]:
|
|
141
|
+
choices = data.get("choices") or []
|
|
142
|
+
if not choices:
|
|
143
|
+
return []
|
|
144
|
+
message = choices[0].get("message") or {}
|
|
145
|
+
raw_calls = message.get("tool_calls") or []
|
|
146
|
+
tool_calls: list[dict[str, Any]] = []
|
|
147
|
+
for call in raw_calls:
|
|
148
|
+
function = call.get("function") or {}
|
|
149
|
+
name = function.get("name")
|
|
150
|
+
arguments = function.get("arguments")
|
|
151
|
+
parsed_args: dict[str, Any]
|
|
152
|
+
if isinstance(arguments, str):
|
|
153
|
+
try:
|
|
154
|
+
parsed_args = json.loads(arguments)
|
|
155
|
+
except Exception:
|
|
156
|
+
parsed_args = {}
|
|
157
|
+
elif isinstance(arguments, dict):
|
|
158
|
+
parsed_args = dict(arguments)
|
|
159
|
+
else:
|
|
160
|
+
parsed_args = {}
|
|
161
|
+
tool_calls.append({"tool": name, "args": parsed_args})
|
|
162
|
+
return tool_calls
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def _detect_provider(model: str, hint: str | None) -> str:
|
|
166
|
+
if hint:
|
|
167
|
+
return hint.lower()
|
|
168
|
+
lowered = (model or "").lower()
|
|
169
|
+
if lowered.startswith("groq:"):
|
|
170
|
+
return "groq"
|
|
171
|
+
return "generic"
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _resolve_inference_url(base_url: str) -> str:
|
|
175
|
+
normalized = (base_url or "").rstrip("/")
|
|
176
|
+
if not normalized:
|
|
177
|
+
raise RuntimeError("inference_url cannot be empty")
|
|
178
|
+
if normalized.endswith("/v1/chat/completions"):
|
|
179
|
+
return normalized
|
|
180
|
+
if normalized.endswith("/chat/completions"):
|
|
181
|
+
return normalized
|
|
182
|
+
if normalized.endswith("/v1"):
|
|
183
|
+
return f"{normalized}/chat/completions"
|
|
184
|
+
if "/v1/" in normalized:
|
|
185
|
+
return f"{normalized}/chat/completions"
|
|
186
|
+
return f"{normalized}/v1/chat/completions"
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
async def _choose_actions(
    client: TaskAppClient,
    provider: str,
    model: str,
    problem: str,
    policy_cfg: dict[str, Any],
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Query the model for tool calls solving *problem*.

    Returns ``(tool_calls, raw_response_body)``. For provider "groq" the
    request is proxied through the task app itself; for any other provider,
    ``policy_cfg["inference_url"]`` is called directly.

    Raises:
        RuntimeError: when inference_url is missing, the request times out,
            or the inference service replies with a 4xx/5xx status.
    """
    messages = _build_messages(problem)
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "tools": _math_tool_schema(),
        # Force the model to answer via the math_submit tool.
        "tool_choice": {"type": "function", "function": {"name": TOOL_NAME}},
        "temperature": policy_cfg.get("temperature", 0.0),
        "top_p": policy_cfg.get("top_p", 1.0),
        "max_tokens": policy_cfg.get("max_tokens", 256),
    }

    if provider == "groq":
        # Task app proxies Groq requests; reuse existing headers on the client
        response = await client.client.post("/proxy/groq/v1/chat/completions", json=payload)
        response.raise_for_status()
        body = response.json()
    else:
        inference_url = policy_cfg.get("inference_url")
        if not inference_url:
            raise RuntimeError("inference_url required for non-groq evaluations")
        # "headers" take precedence; "extra_headers" only fill in missing keys.
        headers = dict(policy_cfg.get("headers") or {})
        for key, value in (policy_cfg.get("extra_headers") or {}).items():
            headers.setdefault(key, value)
        final_url = _resolve_inference_url(inference_url)
        try:
            response = await client.client.post(
                final_url,
                json=payload,
                headers=headers or None,
            )
        except httpx.ReadTimeout as exc:
            raise RuntimeError("Inference request timed out. Check the inference service.") from exc
        # Decode the body before status checks so error messages can quote it.
        try:
            body = response.json()
        except Exception:
            body = {"raw": response.text[:800]}
        if response.status_code >= 500:
            raise RuntimeError(f"Inference server error {response.status_code}: {body}")
        if response.status_code >= 400:
            raise RuntimeError(f"Inference request invalid ({response.status_code}): {body}")
    tool_calls = _parse_tool_calls(body)
    return tool_calls, body
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _tool_to_answer(tool_calls: list[dict[str, Any]]) -> str:
|
|
241
|
+
if not tool_calls:
|
|
242
|
+
return ""
|
|
243
|
+
args = tool_calls[0].get("args") or {}
|
|
244
|
+
answer = str(args.get("answer") or "")
|
|
245
|
+
return answer.strip()
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
async def eval_episode(
    client: TaskAppClient,
    *,
    split: str,
    seed: int | None,
    model: str,
    provider: str,
    policy_cfg: dict[str, Any],
) -> dict[str, Any]:
    """Run one client-side episode: initialize env, query model, step, score.

    Returns a flat result record containing the problem text, the submitted
    and expected answers, the reward, a status string, and the raw model
    response for debugging.
    """
    created = await client.initialize(split, seed)
    env_id = created["env_id"]
    observation = created.get("observation") or {}
    problem = observation.get("problem") or ""

    tool_calls, raw_response = await _choose_actions(client, provider, model, problem, policy_cfg)
    answer = _tool_to_answer(tool_calls)
    result = await client.step(env_id, tool_calls)
    # Best-effort cleanup; terminate() suppresses its own errors.
    await client.terminate(env_id)

    info = result.get("info") or {}
    reward = result.get("reward") or 0.0
    # Fall back to a reward-derived status when the env doesn't report one.
    status = info.get("status") or ("correct" if reward > 0 else "incorrect")
    return {
        "seed": seed,
        "split": split,
        "problem": problem,
        "answer": answer,
        "expected": info.get("expected_answer"),
        "reward": reward,
        "status": status,
        "correct": bool(info.get("correct")),
        "raw_response": raw_response,
        "tool_calls": tool_calls,
    }
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
async def eval_via_rollout(
    client: TaskAppClient,
    *,
    run_id: str,
    split: str,
    seed: int | None,
    model: str,
    policy_cfg: dict[str, Any],
) -> dict[str, Any]:
    """Run one episode through the task app's server-side /rollout endpoint.

    Mirrors eval_episode's result shape so the caller can treat both paths
    uniformly. When the response carries no trajectories, a stub record with
    status "missing" is returned instead.
    """
    payload = {
        "run_id": run_id,
        "env": {
            "env_name": "math",
            "config": {"split": split},
            "seed": seed,
        },
        "policy": {
            "policy_name": "math-single-step",
            "config": policy_cfg,
        },
        # One agent action followed by one env step; then terminate.
        "ops": ["agent", "env"],
        "on_done": "terminate",
    }
    resp = await client.rollout(payload)
    trajs = resp.get("trajectories") or []
    if not trajs:
        return {"reward": 0.0, "correct": False, "status": "missing"}
    traj = trajs[0]
    steps = traj.get("steps") or []
    # Single-step task: only the first step of the first trajectory matters.
    step = steps[0] if steps else {}
    info = step.get("info") or {}
    observation = step.get("obs") or {}
    return {
        "seed": seed,
        "split": split,
        "problem": observation.get("problem"),
        "answer": _tool_to_answer(step.get("tool_calls") or []),
        "expected": info.get("expected_answer"),
        "reward": step.get("reward") or 0.0,
        "status": info.get("status"),
        "correct": bool(info.get("correct")),
        # Full rollout response kept for debugging, unlike eval_episode which
        # keeps only the inference body.
        "raw_response": resp,
        "tool_calls": step.get("tool_calls") or [],
    }
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _load_config(path: str | None) -> dict[str, Any]:
|
|
331
|
+
if not path:
|
|
332
|
+
return {}
|
|
333
|
+
with open(path, "rb") as fh:
|
|
334
|
+
return tomllib.load(fh)
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def _default_policy_cfg(cfg: dict[str, Any]) -> dict[str, Any]:
|
|
338
|
+
policy = dict(cfg.get("policy") or {})
|
|
339
|
+
if "inference_url" not in policy:
|
|
340
|
+
env_url = os.getenv("INFERENCE_URL")
|
|
341
|
+
if env_url:
|
|
342
|
+
policy["inference_url"] = env_url
|
|
343
|
+
for key in ("max_tokens", "temperature", "top_p", "headers", "extra_headers"):
|
|
344
|
+
if key not in policy and key in cfg:
|
|
345
|
+
policy[key] = cfg[key]
|
|
346
|
+
extra_headers = dict(policy.get("extra_headers") or {})
|
|
347
|
+
headers = dict(policy.get("headers") or {})
|
|
348
|
+
if "Authorization" not in headers and "Authorization" not in extra_headers:
|
|
349
|
+
synth_key = os.getenv("SYNTH_API_KEY")
|
|
350
|
+
if synth_key:
|
|
351
|
+
extra_headers["Authorization"] = f"Bearer {synth_key}"
|
|
352
|
+
if extra_headers:
|
|
353
|
+
policy["extra_headers"] = extra_headers
|
|
354
|
+
return policy
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
async def main() -> None:
    """CLI entry point: evaluate a policy over N seeded episodes and print a summary.

    Configuration is layered: TOML file (via --toml) first, then environment
    variables (TASK_APP_URL, EVAL_MODEL, EVAL_SPLIT, NUM_EPISODES), then
    hard-coded defaults.

    Raises:
        RuntimeError: when no task app URL is configured anywhere.
    """
    parser = argparse.ArgumentParser(description="Evaluate math task app policies")
    parser.add_argument("--toml", help="Path to TOML config", default=None)
    parser.add_argument("--use-rollout", action="store_true", help="Use server-side rollout")
    args = parser.parse_args()

    cfg = _load_config(args.toml)
    task_app_url = (cfg.get("task_app_url") or os.getenv("TASK_APP_URL") or "").rstrip("/")
    if not task_app_url:
        raise RuntimeError("task_app_url missing; set in TOML or export TASK_APP_URL")
    model = cfg.get("model") or os.getenv("EVAL_MODEL") or "groq:qwen-2.5-7b"
    split = cfg.get("split") or os.getenv("EVAL_SPLIT") or DEFAULT_SPLIT
    episodes = int(cfg.get("num_episodes") or os.getenv("NUM_EPISODES") or 50)
    seed_start = int(cfg.get("seed_start") or 0)

    policy_cfg = _default_policy_cfg(cfg)
    # Provider hint may live at top level, inside [policy], or in the derived cfg.
    provider_hint = (
        cfg.get("provider") or cfg.get("policy", {}).get("provider") or policy_cfg.get("provider")
    )
    provider = _detect_provider(model, provider_hint)
    # "provider" is routing metadata, not an inference parameter — drop it.
    policy_cfg.pop("provider", None)

    api_key = os.getenv("ENVIRONMENT_API_KEY")

    successes = 0
    # Counts per status string, incremented only for incorrect episodes.
    failures: dict[str, int] = {}
    results: list[dict[str, Any]] = []

    async with TaskAppClient(task_app_url, api_key=api_key) as client:
        for episode in range(episodes):
            # Sequential seeds give a deterministic, reproducible problem set.
            seed = seed_start + episode
            if args.use_rollout:
                data = await eval_via_rollout(
                    client,
                    run_id=f"eval-{seed}",
                    split=split,
                    seed=seed,
                    model=model,
                    policy_cfg={"model": model, **policy_cfg},
                )
            else:
                data = await eval_episode(
                    client,
                    split=split,
                    seed=seed,
                    model=model,
                    provider=provider,
                    policy_cfg={"model": model, **policy_cfg},
                )
            results.append(data)
            if data.get("correct"):
                successes += 1
            status = data.get("status") or "unknown"
            failures[status] = failures.get(status, 0) + (0 if data.get("correct") else 1)
            answer = data.get("answer")
            expected = data.get("expected")
            problem = data.get("problem")
            tool_calls = data.get("tool_calls") or []
            print(
                f"Episode {episode + 1}/{episodes} seed={seed} status={status} reward={data.get('reward')}\n"
                f" problem: {problem!r}\n"
                f" tool : {tool_calls!r}\n"
                f" answer : {answer!r}\n expected: {expected!r}",
                flush=True,
            )

    # max() guards against division by zero when episodes == 0.
    accuracy = successes / max(episodes, 1)
    print("=== Evaluation Summary ===")
    print(f"Task App: {task_app_url}")
    print(f"Model: {model}")
    print(f"Split: {split}")
    print(f"Episodes: {episodes}")
    print(f"Accuracy: {accuracy:.3f}")
    print("Failure breakdown:")
    # Most frequent failure statuses first; ties broken alphabetically.
    for status, count in sorted(failures.items(), key=lambda kv: (-kv[1], kv[0])):
        print(f" {status}: {count}")


if __name__ == "__main__":
    asyncio.run(main())
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Submit math RL training jobs via Synth backend."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import tomllib
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import requests
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _load_toml(path: Path) -> dict[str, Any]:
|
|
18
|
+
if not path.exists():
|
|
19
|
+
print(f"config not found: {path}", file=sys.stderr)
|
|
20
|
+
sys.exit(2)
|
|
21
|
+
with path.open("rb") as fh:
|
|
22
|
+
return tomllib.load(fh)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def main() -> None:
    """CLI entry point: validate an RL TOML config and POST it to the backend.

    The task service URL is resolved in priority order: --task-url flag,
    TASK_APP_URL env var, then services.task_url in the TOML. Exits 2 on
    validation errors, 1 on a non-2xx backend response, 0 on success.
    """
    parser = argparse.ArgumentParser(description="Create math RL job via backend RL endpoint")
    parser.add_argument(
        "--backend", default=os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api")
    )
    parser.add_argument("--config", required=True, help="Path to RL TOML config")
    parser.add_argument(
        "--task-url", default=os.getenv("TASK_APP_URL", ""), help="Override task service URL"
    )
    parser.add_argument(
        "--idempotency",
        default=os.getenv("RL_IDEMPOTENCY_KEY", ""),
        help="Optional Idempotency-Key header",
    )
    args = parser.parse_args()

    cfg_path = Path(args.config).expanduser()
    cfg = _load_toml(cfg_path)

    # Treat a non-dict [services] table as absent rather than erroring.
    services = cfg.get("services") if isinstance(cfg.get("services"), dict) else {}

    task_url = (
        (args.task_url or "").strip()
        or (os.getenv("TASK_APP_URL") or "").strip()
        or (services.get("task_url") or "").strip()
    )
    if not task_url:
        print(
            "Missing task service URL. Provide --task-url or set TASK_APP_URL or services.task_url in TOML",
            file=sys.stderr,
        )
        sys.exit(2)

    model_cfg = cfg.get("model") if isinstance(cfg.get("model"), dict) else {}
    has_source = bool((model_cfg.get("source") or "").strip())
    has_base = bool((model_cfg.get("base") or "").strip())
    # Exactly one of source/base must be set; equality means both or neither.
    if has_source == has_base:
        print(
            "Model section must specify exactly one of [model].source or [model].base",
            file=sys.stderr,
        )
        sys.exit(2)

    # The whole TOML config is forwarded verbatim under data.config.
    payload: dict[str, Any] = {
        "job_type": "rl",
        "compute": cfg.get("compute", {}),
        "data": {
            "endpoint_base_url": task_url,
            "config": cfg,
        },
        "tags": cfg.get("tags", {}),
    }

    backend = str(args.backend).rstrip("/")
    url = f"{backend}/rl/jobs"
    api_key = (os.getenv("SYNTH_API_KEY") or os.getenv("SYNTH_KEY") or "").strip()
    if not api_key:
        print("Missing SYNTH_API_KEY in env", file=sys.stderr)
        sys.exit(2)

    headers = {
        "content-type": "application/json",
        "authorization": f"Bearer {api_key}",
    }
    idem = (args.idempotency or "").strip()
    if idem:
        headers["Idempotency-Key"] = idem

    print(f"[INFO] POST {url}")
    # Log only the config's top-level keys, never its values (may hold secrets).
    try:
        preview = {"job_type": payload["job_type"], "data": {"config_keys": list(cfg.keys())}}
        print(f"[INFO] Payload preview: {json.dumps(preview)}")
    except Exception:
        pass

    resp = requests.post(url, headers=headers, json=payload, timeout=120)
    ok = resp.status_code in (200, 201)
    try:
        snippet = resp.json()
    except Exception:
        snippet = resp.text[:300]
    print(f"[INFO] Response: {resp.status_code} {snippet}")
    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Math Single-Step Task App
|
|
2
|
+
|
|
3
|
+
This directory hosts the legacy entrypoint for the math single-step task app. Prefer starting the app via:
|
|
4
|
+
|
|
5
|
+
```bash
|
|
6
|
+
uvx synth-ai serve math-single-step --env-file examples/rl/.env --port 8101
|
|
7
|
+
```
|
|
8
|
+
|
|
9
|
+
If you need to run it directly (e.g., for Modal `modal deploy` compatibility), use:
|
|
10
|
+
|
|
11
|
+
```bash
|
|
12
|
+
python examples/rl/task_app/math_task_app.py --env-file examples/rl/.env --port 8101
|
|
13
|
+
```
|
|
14
|
+
|
|
15
|
+
Environment variables:
|
|
16
|
+
|
|
17
|
+
- `MATH_DATASET_NAME` – defaults to `EleutherAI/math`
|
|
18
|
+
- `MATH_DATASET_CONFIG` – defaults to `algebra__linear_1d`
|
|
19
|
+
- `MATH_DATASET_DEFAULT_SPLIT`, `MATH_DATASET_VALIDATION_SPLIT`, `MATH_DATASET_TEST_SPLIT`
|
|
20
|
+
|
|
21
|
+
The task app enforces a single `math_submit` tool call per episode, enabling RL to reward correct final answers and penalise missing or malformed submissions.
|
|
22
|
+
|