synth-ai 0.2.8.dev12__py3-none-any.whl → 0.2.9.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/api/train/__init__.py +5 -0
- synth_ai/api/train/builders.py +165 -0
- synth_ai/api/train/cli.py +450 -0
- synth_ai/api/train/config_finder.py +168 -0
- synth_ai/api/train/env_resolver.py +302 -0
- synth_ai/api/train/pollers.py +66 -0
- synth_ai/api/train/task_app.py +193 -0
- synth_ai/api/train/utils.py +232 -0
- synth_ai/cli/__init__.py +23 -0
- synth_ai/cli/rl_demo.py +18 -6
- synth_ai/cli/root.py +38 -6
- synth_ai/cli/task_apps.py +1107 -0
- synth_ai/demo_registry.py +258 -0
- synth_ai/demos/core/cli.py +147 -111
- synth_ai/demos/demo_task_apps/__init__.py +7 -1
- synth_ai/demos/demo_task_apps/math/config.toml +55 -110
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +157 -21
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +39 -0
- synth_ai/task/__init__.py +94 -1
- synth_ai/task/apps/__init__.py +88 -0
- synth_ai/task/apps/grpo_crafter.py +438 -0
- synth_ai/task/apps/math_single_step.py +852 -0
- synth_ai/task/auth.py +153 -0
- synth_ai/task/client.py +165 -0
- synth_ai/task/contracts.py +29 -14
- synth_ai/task/datasets.py +105 -0
- synth_ai/task/errors.py +49 -0
- synth_ai/task/json.py +77 -0
- synth_ai/task/proxy.py +258 -0
- synth_ai/task/rubrics.py +212 -0
- synth_ai/task/server.py +398 -0
- synth_ai/task/tracing_utils.py +79 -0
- synth_ai/task/vendors.py +61 -0
- synth_ai/tracing_v3/session_tracer.py +13 -5
- synth_ai/tracing_v3/storage/base.py +10 -12
- synth_ai/tracing_v3/turso/manager.py +20 -6
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/METADATA +3 -2
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/RECORD +42 -18
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/top_level.txt +0 -0
|
@@ -1,129 +1,74 @@
|
|
|
1
|
-
[
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
trainer_mode = "full"
|
|
6
|
-
|
|
7
|
-
[lora]
|
|
8
|
-
r = 16
|
|
9
|
-
alpha = 32
|
|
10
|
-
dropout = 0.05
|
|
11
|
-
target_modules = [
|
|
12
|
-
"q_proj", "k_proj", "v_proj", "o_proj",
|
|
13
|
-
"gate_proj", "up_proj", "down_proj",
|
|
14
|
-
]
|
|
15
|
-
|
|
16
|
-
[rdma]
|
|
17
|
-
enabled = false
|
|
18
|
-
ifname = "eth0"
|
|
19
|
-
ip_type = "ipv4"
|
|
20
|
-
p2p_disable = 0
|
|
21
|
-
shm_disable = 0
|
|
22
|
-
fast_nccl = false
|
|
23
|
-
|
|
24
|
-
gid_index = 3
|
|
25
|
-
cross_nic = 0
|
|
26
|
-
collnet_enable = 0
|
|
27
|
-
net_gdr_level = 2
|
|
28
|
-
|
|
29
|
-
nsocks_perthread = 4
|
|
30
|
-
socket_nthreads = 2
|
|
1
|
+
[algorithm]
|
|
2
|
+
type = "online"
|
|
3
|
+
method = "policy_gradient"
|
|
4
|
+
variety = "gspo"
|
|
31
5
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
p2p_level = "SYS"
|
|
35
|
-
debug = "INFO"
|
|
6
|
+
[services]
|
|
7
|
+
task_url = "http://localhost:8101"
|
|
36
8
|
|
|
37
|
-
[
|
|
38
|
-
|
|
39
|
-
gpu_index = 1
|
|
40
|
-
port = 8002
|
|
41
|
-
tp = 1
|
|
42
|
-
health_max_wait_s = 180
|
|
43
|
-
health_interval_ms = 300
|
|
9
|
+
[model]
|
|
10
|
+
base = "Qwen/Qwen3-1.7B"
|
|
44
11
|
|
|
45
|
-
[
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
12
|
+
[policy]
|
|
13
|
+
model = "Qwen/Qwen3-1.7B"
|
|
14
|
+
inference_url = "http://localhost:8000/api/inference"
|
|
15
|
+
max_tokens = 1028
|
|
16
|
+
temperature = 0.2
|
|
17
|
+
|
|
18
|
+
[data]
|
|
19
|
+
split = "train"
|
|
20
|
+
seed_start = 0
|
|
21
|
+
episodes_per_iteration = 1280 # 8 per group * 4 groups per batch * 2 batches per step * 20 steps
|
|
22
|
+
evaluation_split = "validation"
|
|
23
|
+
evaluation_episodes = 50
|
|
53
24
|
|
|
54
25
|
[training]
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
batch_size =
|
|
58
|
-
group_size =
|
|
26
|
+
max_turns = 1
|
|
27
|
+
ops = ["agent", "env"]
|
|
28
|
+
batch_size = 2
|
|
29
|
+
group_size = 16
|
|
30
|
+
reward_positive = 1.0
|
|
31
|
+
reward_negative_no_tool = -1.0
|
|
32
|
+
reward_negative_no_answer = -0.5
|
|
59
33
|
learning_rate = 5e-6
|
|
60
|
-
max_grad_norm = 0.5
|
|
61
34
|
log_interval = 1
|
|
62
|
-
update_reference_interval = 0
|
|
63
35
|
weight_sync_interval = 1
|
|
64
36
|
|
|
65
37
|
[training.weight_sync]
|
|
66
38
|
enable = true
|
|
67
39
|
targets = ["policy"]
|
|
68
40
|
|
|
69
|
-
[
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
max_model_len = 8192
|
|
73
|
-
max_num_seqs = 32
|
|
74
|
-
enforce_eager = false
|
|
75
|
-
max_parallel_generations = 4
|
|
41
|
+
[compute]
|
|
42
|
+
gpu_type = "H100"
|
|
43
|
+
gpu_count = 4
|
|
76
44
|
|
|
77
|
-
[
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
every_n_iters = 5
|
|
84
|
-
|
|
85
|
-
[rollout]
|
|
86
|
-
env_name = "math"
|
|
87
|
-
policy_name = "math-react"
|
|
88
|
-
env_config = {}
|
|
89
|
-
max_steps_per_episode = 5
|
|
90
|
-
sampling_temperature = 0.3
|
|
91
|
-
sampling_top_p = 0.95
|
|
92
|
-
max_tokens = 1024
|
|
93
|
-
max_concurrent_rollouts = 4
|
|
94
|
-
ops_per_rollout = 14
|
|
95
|
-
on_done = "reset"
|
|
96
|
-
thinking_mode = "think"
|
|
97
|
-
thinking_budget = 512
|
|
45
|
+
[topology]
|
|
46
|
+
type = "single_node_split"
|
|
47
|
+
gpus_for_vllm = 2
|
|
48
|
+
gpus_for_training = 1
|
|
49
|
+
gpus_for_ref = 1
|
|
50
|
+
tensor_parallel = 1
|
|
98
51
|
|
|
99
|
-
[
|
|
100
|
-
|
|
52
|
+
[vllm]
|
|
53
|
+
tensor_parallel_size = 1
|
|
54
|
+
max_model_len = 4096
|
|
101
55
|
|
|
102
|
-
[
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
advantage_normalization = true
|
|
109
|
-
group_normalization = true
|
|
110
|
-
num_inner_steps = 1
|
|
111
|
-
clip_epsilon = 0.2
|
|
112
|
-
completion_only = false
|
|
56
|
+
[reference]
|
|
57
|
+
placement = "dedicated"
|
|
58
|
+
port = 8002
|
|
59
|
+
tp = 1
|
|
60
|
+
health_max_wait_s = 180
|
|
61
|
+
health_interval_ms = 300
|
|
113
62
|
|
|
114
|
-
[
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
indicator_lambda = 0.0
|
|
63
|
+
[rollout]
|
|
64
|
+
policy_name = "math-single-step"
|
|
65
|
+
max_turns = 1
|
|
66
|
+
episodes_per_batch = 32 # group_size * batch_size
|
|
119
67
|
|
|
120
|
-
[
|
|
121
|
-
|
|
68
|
+
[evaluation]
|
|
69
|
+
instances = 32
|
|
70
|
+
every_n_iters = 10
|
|
71
|
+
seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
|
|
122
72
|
|
|
123
|
-
[
|
|
124
|
-
|
|
125
|
-
directory = "/checkpoints"
|
|
126
|
-
keep_last_n = 3
|
|
127
|
-
save_optimizer = true
|
|
128
|
-
save_scheduler = true
|
|
129
|
-
enabled = true
|
|
73
|
+
[tags]
|
|
74
|
+
experiment = "math_single_step_qwen17"
|
|
@@ -7,6 +7,9 @@ from pathlib import Path
|
|
|
7
7
|
|
|
8
8
|
from modal import App, Image, Secret, asgi_app
|
|
9
9
|
from functools import lru_cache
|
|
10
|
+
from typing import Iterable
|
|
11
|
+
|
|
12
|
+
from starlette.requests import Request
|
|
10
13
|
|
|
11
14
|
try: # Backward compatibility with older installed SDKs
|
|
12
15
|
from synth_ai.demos.demo_task_apps.core import DEFAULT_TASK_APP_SECRET_NAME
|
|
@@ -95,18 +98,77 @@ app = App("hendrycks-math-task-app")
|
|
|
95
98
|
def fastapi_app():
|
|
96
99
|
import httpx
|
|
97
100
|
from fastapi import Body, HTTPException, status
|
|
98
|
-
from fastapi import FastAPI
|
|
101
|
+
from fastapi import FastAPI
|
|
99
102
|
from fastapi.middleware.cors import CORSMiddleware
|
|
100
103
|
from fastapi.responses import JSONResponse
|
|
104
|
+
try:
|
|
105
|
+
from synth_ai.task.auth import (
|
|
106
|
+
is_api_key_header_authorized,
|
|
107
|
+
normalize_environment_api_key,
|
|
108
|
+
)
|
|
109
|
+
except Exception: # pragma: no cover - fallback for older synth-ai builds
|
|
110
|
+
def _normalize_env_key_fallback() -> str | None:
|
|
111
|
+
key = os.getenv("ENVIRONMENT_API_KEY")
|
|
112
|
+
if key:
|
|
113
|
+
return key
|
|
114
|
+
for alias in ("dev_environment_api_key", "DEV_ENVIRONMENT_API_KEY"):
|
|
115
|
+
candidate = os.getenv(alias)
|
|
116
|
+
if candidate:
|
|
117
|
+
os.environ["ENVIRONMENT_API_KEY"] = candidate
|
|
118
|
+
return candidate
|
|
119
|
+
return None
|
|
120
|
+
|
|
121
|
+
def normalize_environment_api_key() -> str | None: # type: ignore[override]
|
|
122
|
+
return _normalize_env_key_fallback()
|
|
123
|
+
|
|
124
|
+
def _header_values(request: Request, header: str) -> Iterable[str]:
|
|
125
|
+
raw = request.headers.get(header) or request.headers.get(header.lower())
|
|
126
|
+
return [raw] if raw is not None else []
|
|
127
|
+
|
|
128
|
+
def _split(values: Iterable[str]) -> list[str]:
|
|
129
|
+
parts: list[str] = []
|
|
130
|
+
for value in values:
|
|
131
|
+
if not isinstance(value, str):
|
|
132
|
+
continue
|
|
133
|
+
for chunk in value.split(','):
|
|
134
|
+
chunk = chunk.strip()
|
|
135
|
+
if chunk:
|
|
136
|
+
parts.append(chunk)
|
|
137
|
+
return parts
|
|
138
|
+
|
|
139
|
+
def is_api_key_header_authorized(request: Request) -> bool: # type: ignore[override]
|
|
140
|
+
expected = normalize_environment_api_key()
|
|
141
|
+
if not expected:
|
|
142
|
+
return False
|
|
143
|
+
single = _header_values(request, "x-api-key")
|
|
144
|
+
multi = _header_values(request, "x-api-keys")
|
|
145
|
+
auth = _header_values(request, "authorization")
|
|
146
|
+
bearer = []
|
|
147
|
+
for token in auth:
|
|
148
|
+
if isinstance(token, str) and token.lower().startswith("bearer "):
|
|
149
|
+
bearer.append(token.split(" ", 1)[1].strip())
|
|
150
|
+
candidates = _split(single + multi + bearer)
|
|
151
|
+
return any(candidate == expected for candidate in candidates)
|
|
101
152
|
|
|
102
153
|
# Inline, self-contained FastAPI app (math-only)
|
|
103
154
|
@lru_cache(maxsize=1)
|
|
104
155
|
def _hf_split(subject: str, split: str, slice_spec: str | None = None):
|
|
105
156
|
from datasets import load_dataset # type: ignore
|
|
157
|
+
|
|
106
158
|
s = split
|
|
107
159
|
if slice_spec:
|
|
108
160
|
s = f"{s}{slice_spec}"
|
|
109
|
-
|
|
161
|
+
|
|
162
|
+
try:
|
|
163
|
+
return load_dataset("nlile/hendrycks-MATH-benchmark", subject, split=s)
|
|
164
|
+
except ValueError:
|
|
165
|
+
base = load_dataset("nlile/hendrycks-MATH-benchmark", split=s)
|
|
166
|
+
if subject and subject not in {"", "default"}:
|
|
167
|
+
if "subject" in base.column_names:
|
|
168
|
+
base = base.filter(lambda ex: ex.get("subject") == subject)
|
|
169
|
+
elif isinstance(base, list):
|
|
170
|
+
base = [ex for ex in base if ex.get("subject") == subject]
|
|
171
|
+
return base
|
|
110
172
|
|
|
111
173
|
def _normalize_answer_text(s: str) -> str:
|
|
112
174
|
import re as _re
|
|
@@ -121,6 +183,9 @@ def fastapi_app():
|
|
|
121
183
|
subj = subject or os.getenv("HENDRYCKS_MATH_CONFIG", "default")
|
|
122
184
|
ds = _hf_split(subj, os.getenv("HENDRYCKS_MATH_SPLIT", "test"), os.getenv("HENDRYCKS_MATH_SLICE"))
|
|
123
185
|
n = len(ds) if hasattr(ds, "__len__") else 0
|
|
186
|
+
if n == 0 and subject not in {"", "default"}:
|
|
187
|
+
ds = _hf_split("default", os.getenv("HENDRYCKS_MATH_SPLIT", "test"), os.getenv("HENDRYCKS_MATH_SLICE"))
|
|
188
|
+
n = len(ds) if hasattr(ds, "__len__") else 0
|
|
124
189
|
if n == 0:
|
|
125
190
|
raise RuntimeError("Hendrycks MATH dataset loaded empty")
|
|
126
191
|
idx = abs(int(seed)) % n
|
|
@@ -158,6 +223,53 @@ def fastapi_app():
|
|
|
158
223
|
logger.info(msg)
|
|
159
224
|
return prefix
|
|
160
225
|
|
|
226
|
+
def _resolve_env_keys() -> set[str]:
|
|
227
|
+
keys: set[str] = set()
|
|
228
|
+
for alias in ("ENVIRONMENT_API_KEY", "dev_environment_api_key", "DEV_ENVIRONMENT_API_KEY"):
|
|
229
|
+
value = os.environ.get(alias)
|
|
230
|
+
if value:
|
|
231
|
+
os.environ.setdefault("ENVIRONMENT_API_KEY", value)
|
|
232
|
+
keys.add(value)
|
|
233
|
+
alias_env = os.environ.get("ENVIRONMENT_API_KEY_ALIASES", "")
|
|
234
|
+
for chunk in alias_env.split(","):
|
|
235
|
+
trimmed = chunk.strip()
|
|
236
|
+
if trimmed:
|
|
237
|
+
keys.add(trimmed)
|
|
238
|
+
return keys
|
|
239
|
+
|
|
240
|
+
def _extract_header_candidates(
|
|
241
|
+
request: Request,
|
|
242
|
+
x_api_key: str | None,
|
|
243
|
+
x_api_keys: str | None,
|
|
244
|
+
authorization: str | None,
|
|
245
|
+
) -> list[str]:
|
|
246
|
+
headers = request.headers
|
|
247
|
+
candidates: list[str] = []
|
|
248
|
+
primary = x_api_key or headers.get("x-api-key")
|
|
249
|
+
if primary:
|
|
250
|
+
candidates.append(primary.strip())
|
|
251
|
+
secondary = x_api_keys or headers.get("x-api-keys")
|
|
252
|
+
if secondary:
|
|
253
|
+
candidates.extend([value.strip() for value in secondary.split(",") if value.strip()])
|
|
254
|
+
auth_header = authorization or headers.get("authorization") or headers.get("Authorization")
|
|
255
|
+
if auth_header and auth_header.lower().startswith("bearer "):
|
|
256
|
+
token = auth_header.split(" ", 1)[1].strip()
|
|
257
|
+
if token:
|
|
258
|
+
candidates.append(token)
|
|
259
|
+
return [c for c in candidates if c]
|
|
260
|
+
|
|
261
|
+
def _is_authorized(
|
|
262
|
+
request: Request,
|
|
263
|
+
x_api_key: str | None,
|
|
264
|
+
x_api_keys: str | None,
|
|
265
|
+
authorization: str | None,
|
|
266
|
+
) -> bool:
|
|
267
|
+
keys = _resolve_env_keys()
|
|
268
|
+
if not keys:
|
|
269
|
+
return False
|
|
270
|
+
candidates = _extract_header_candidates(request, x_api_key, x_api_keys, authorization)
|
|
271
|
+
return any(candidate in keys for candidate in candidates)
|
|
272
|
+
|
|
161
273
|
@app.get("/info")
|
|
162
274
|
async def info():
|
|
163
275
|
return {
|
|
@@ -166,42 +278,47 @@ def fastapi_app():
|
|
|
166
278
|
}
|
|
167
279
|
|
|
168
280
|
@app.get("/health")
|
|
169
|
-
async def health(
|
|
170
|
-
|
|
281
|
+
async def health(request: Request):
|
|
282
|
+
env_keys = _resolve_env_keys()
|
|
283
|
+
env_key = next(iter(env_keys), None)
|
|
171
284
|
if not env_key:
|
|
172
285
|
return JSONResponse(status_code=503, content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"})
|
|
173
|
-
|
|
286
|
+
# Authorize using all header variants; avoid typed Header params to prevent 422s
|
|
287
|
+
authorized = is_api_key_header_authorized(request)
|
|
288
|
+
if not authorized:
|
|
174
289
|
prefix = _log_env_key_prefix("health", env_key)
|
|
175
|
-
content = {
|
|
176
|
-
|
|
290
|
+
content = {
|
|
291
|
+
"status": "healthy",
|
|
292
|
+
"authorized": False,
|
|
293
|
+
}
|
|
177
294
|
if prefix:
|
|
178
|
-
content["detail"] = f"Invalid API key (expected prefix: {prefix})"
|
|
179
295
|
content["expected_api_key_prefix"] = prefix
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
return {"status": "healthy"}
|
|
296
|
+
return JSONResponse(status_code=200, content=content)
|
|
297
|
+
return {"status": "healthy", "authorized": True}
|
|
183
298
|
|
|
184
299
|
# Optional rollout-specific health for CLI compatibility
|
|
185
300
|
@app.get("/health/rollout")
|
|
186
|
-
async def health_rollout(
|
|
187
|
-
|
|
301
|
+
async def health_rollout(request: Request):
|
|
302
|
+
env_keys = _resolve_env_keys()
|
|
303
|
+
env_key = next(iter(env_keys), None)
|
|
188
304
|
if not env_key:
|
|
189
305
|
return JSONResponse(status_code=503, content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"})
|
|
190
|
-
|
|
306
|
+
authorized = is_api_key_header_authorized(request)
|
|
307
|
+
if not authorized:
|
|
191
308
|
prefix = _log_env_key_prefix("health/rollout", env_key)
|
|
192
|
-
content = {
|
|
193
|
-
|
|
309
|
+
content = {
|
|
310
|
+
"status": "healthy",
|
|
311
|
+
"authorized": False,
|
|
312
|
+
}
|
|
194
313
|
if prefix:
|
|
195
|
-
content["detail"] = f"Invalid or missing API key (expected prefix: {prefix})"
|
|
196
314
|
content["expected_api_key_prefix"] = prefix
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
return {"ok": True}
|
|
315
|
+
return JSONResponse(status_code=200, content=content)
|
|
316
|
+
return {"ok": True, "authorized": True}
|
|
200
317
|
|
|
201
318
|
# _load_hendrycks_problem is defined at fastapi_app scope
|
|
202
319
|
|
|
203
320
|
@app.get("/task_info")
|
|
204
|
-
async def task_info(seed: int = 0, subject: str = "
|
|
321
|
+
async def task_info(seed: int = 0, subject: str = "default"):
|
|
205
322
|
"""Return Hendrycks MATH problem/answer and tool schema for a seed."""
|
|
206
323
|
q, a = _load_hendrycks_problem(int(seed), subject=subject)
|
|
207
324
|
tools = [{
|
|
@@ -229,6 +346,25 @@ def fastapi_app():
|
|
|
229
346
|
|
|
230
347
|
api = create_app()
|
|
231
348
|
|
|
349
|
+
# Always log and surface 422 validation errors with header presence snapshot
|
|
350
|
+
from fastapi.exceptions import RequestValidationError
|
|
351
|
+
|
|
352
|
+
@api.exception_handler(RequestValidationError)
|
|
353
|
+
async def _on_validation_error(request: Request, exc: RequestValidationError):
|
|
354
|
+
try:
|
|
355
|
+
hdr = request.headers
|
|
356
|
+
snapshot = {
|
|
357
|
+
"path": str(getattr(request, "url").path),
|
|
358
|
+
"have_x_api_key": bool(hdr.get("x-api-key")),
|
|
359
|
+
"have_x_api_keys": bool(hdr.get("x-api-keys")),
|
|
360
|
+
"have_authorization": bool(hdr.get("authorization")),
|
|
361
|
+
"errors": exc.errors()[:5],
|
|
362
|
+
}
|
|
363
|
+
print("[422] validation", snapshot, flush=True)
|
|
364
|
+
except Exception:
|
|
365
|
+
pass
|
|
366
|
+
return JSONResponse(status_code=422, content={"status": "invalid", "detail": exc.errors()[:5]})
|
|
367
|
+
|
|
232
368
|
@api.get("/")
|
|
233
369
|
async def root_probe():
|
|
234
370
|
return {"status": "ok", "service": "math"}
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Task app registry entry for the math demo Modal deployment."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
6
|
+
from synth_ai.task.apps.math_single_step import build_config as base_build_config
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
DEMO_MODAL_CONFIG = ModalDeploymentConfig(
|
|
10
|
+
app_name="hendrycks-math-task-app",
|
|
11
|
+
pip_packages=(
|
|
12
|
+
"fastapi>=0.110.0",
|
|
13
|
+
"uvicorn>=0.23.0",
|
|
14
|
+
"pydantic>=2.6.0",
|
|
15
|
+
"httpx>=0.24.0",
|
|
16
|
+
"numpy>=1.24.0",
|
|
17
|
+
"aiohttp>=3.8.0",
|
|
18
|
+
"datasets>=2.16.0",
|
|
19
|
+
"synth-ai",
|
|
20
|
+
),
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def build_config():
|
|
25
|
+
"""Reuse the shared math single-step TaskAppConfig."""
|
|
26
|
+
|
|
27
|
+
return base_build_config()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
register_task_app(
|
|
31
|
+
entry=TaskAppEntry(
|
|
32
|
+
app_id="hendrycks-math-demo",
|
|
33
|
+
description="Demo math task app (Modal-focused) shipping with synth-ai demos.",
|
|
34
|
+
config_factory=build_config,
|
|
35
|
+
env_files=("examples/rl/.env",),
|
|
36
|
+
modal=DEMO_MODAL_CONFIG,
|
|
37
|
+
)
|
|
38
|
+
)
|
|
39
|
+
|
synth_ai/task/__init__.py
CHANGED
|
@@ -1,10 +1,103 @@
|
|
|
1
1
|
from .validators import validate_task_app_url
|
|
2
2
|
from .health import task_app_health
|
|
3
|
-
from .contracts import
|
|
3
|
+
from .contracts import (
|
|
4
|
+
TaskAppContract,
|
|
5
|
+
TaskAppEndpoints,
|
|
6
|
+
RolloutEnvSpec,
|
|
7
|
+
RolloutPolicySpec,
|
|
8
|
+
RolloutRecordConfig,
|
|
9
|
+
RolloutSafetyConfig,
|
|
10
|
+
RolloutRequest,
|
|
11
|
+
RolloutResponse,
|
|
12
|
+
RolloutTrajectory,
|
|
13
|
+
RolloutStep,
|
|
14
|
+
RolloutMetrics,
|
|
15
|
+
TaskInfo,
|
|
16
|
+
)
|
|
17
|
+
from .json import to_jsonable
|
|
18
|
+
from .auth import (
|
|
19
|
+
normalize_environment_api_key,
|
|
20
|
+
is_api_key_header_authorized,
|
|
21
|
+
require_api_key_dependency,
|
|
22
|
+
)
|
|
23
|
+
from .vendors import (
|
|
24
|
+
normalize_vendor_keys,
|
|
25
|
+
get_openai_key_or_503,
|
|
26
|
+
get_groq_key_or_503,
|
|
27
|
+
)
|
|
28
|
+
from .proxy import (
|
|
29
|
+
INTERACT_TOOL_SCHEMA,
|
|
30
|
+
prepare_for_openai,
|
|
31
|
+
prepare_for_groq,
|
|
32
|
+
inject_system_hint,
|
|
33
|
+
extract_message_text,
|
|
34
|
+
parse_tool_call_from_text,
|
|
35
|
+
synthesize_tool_call_if_missing,
|
|
36
|
+
)
|
|
37
|
+
from .datasets import TaskDatasetSpec, TaskDatasetRegistry
|
|
38
|
+
from .rubrics import (
|
|
39
|
+
Criterion,
|
|
40
|
+
Rubric,
|
|
41
|
+
load_rubric,
|
|
42
|
+
blend_rubrics,
|
|
43
|
+
score_events_against_rubric,
|
|
44
|
+
score_outcome_against_rubric,
|
|
45
|
+
)
|
|
46
|
+
from .client import TaskAppClient
|
|
47
|
+
from .errors import error_payload, http_exception, json_error_response
|
|
4
48
|
|
|
49
|
+
|
|
50
|
+
from .server import (
|
|
51
|
+
TaskAppConfig,
|
|
52
|
+
ProxyConfig,
|
|
53
|
+
RubricBundle,
|
|
54
|
+
create_task_app,
|
|
55
|
+
run_task_app,
|
|
56
|
+
)
|
|
5
57
|
__all__ = [
|
|
6
58
|
"validate_task_app_url",
|
|
7
59
|
"task_app_health",
|
|
8
60
|
"TaskAppContract",
|
|
9
61
|
"TaskAppEndpoints",
|
|
62
|
+
"RolloutEnvSpec",
|
|
63
|
+
"RolloutPolicySpec",
|
|
64
|
+
"RolloutRecordConfig",
|
|
65
|
+
"RolloutSafetyConfig",
|
|
66
|
+
"RolloutRequest",
|
|
67
|
+
"RolloutResponse",
|
|
68
|
+
"RolloutTrajectory",
|
|
69
|
+
"RolloutStep",
|
|
70
|
+
"RolloutMetrics",
|
|
71
|
+
"TaskInfo",
|
|
72
|
+
"to_jsonable",
|
|
73
|
+
"normalize_environment_api_key",
|
|
74
|
+
"is_api_key_header_authorized",
|
|
75
|
+
"require_api_key_dependency",
|
|
76
|
+
"normalize_vendor_keys",
|
|
77
|
+
"get_openai_key_or_503",
|
|
78
|
+
"get_groq_key_or_503",
|
|
79
|
+
"INTERACT_TOOL_SCHEMA",
|
|
80
|
+
"prepare_for_openai",
|
|
81
|
+
"prepare_for_groq",
|
|
82
|
+
"inject_system_hint",
|
|
83
|
+
"extract_message_text",
|
|
84
|
+
"parse_tool_call_from_text",
|
|
85
|
+
"synthesize_tool_call_if_missing",
|
|
86
|
+
"TaskDatasetSpec",
|
|
87
|
+
"TaskDatasetRegistry",
|
|
88
|
+
"Criterion",
|
|
89
|
+
"Rubric",
|
|
90
|
+
"load_rubric",
|
|
91
|
+
"blend_rubrics",
|
|
92
|
+
"score_events_against_rubric",
|
|
93
|
+
"score_outcome_against_rubric",
|
|
94
|
+
"TaskAppClient",
|
|
95
|
+
"error_payload",
|
|
96
|
+
"http_exception",
|
|
97
|
+
"json_error_response",
|
|
98
|
+
"run_task_app",
|
|
99
|
+
"create_task_app",
|
|
100
|
+
"RubricBundle",
|
|
101
|
+
"ProxyConfig",
|
|
102
|
+
"TaskAppConfig",
|
|
10
103
|
]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Registry for Task Apps exposed via the shared FastAPI harness."""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Callable, Dict, Iterable, List, Sequence
|
|
7
|
+
|
|
8
|
+
from ..server import TaskAppConfig
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(slots=True)
|
|
12
|
+
class ModalDeploymentConfig:
|
|
13
|
+
"""Modal deployment defaults for a task app."""
|
|
14
|
+
|
|
15
|
+
app_name: str
|
|
16
|
+
python_version: str = "3.11"
|
|
17
|
+
pip_packages: Sequence[str] = field(default_factory=tuple)
|
|
18
|
+
extra_local_dirs: Sequence[tuple[str, str]] = field(default_factory=tuple)
|
|
19
|
+
secret_names: Sequence[str] = field(default_factory=tuple)
|
|
20
|
+
volume_mounts: Sequence[tuple[str, str]] = field(default_factory=tuple)
|
|
21
|
+
timeout: int = 600
|
|
22
|
+
memory: int = 4096
|
|
23
|
+
cpu: float = 2.0
|
|
24
|
+
min_containers: int = 1
|
|
25
|
+
max_containers: int = 4
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(slots=True)
|
|
29
|
+
class TaskAppEntry:
|
|
30
|
+
"""Metadata describing a registered task app."""
|
|
31
|
+
|
|
32
|
+
app_id: str
|
|
33
|
+
description: str
|
|
34
|
+
config_factory: Callable[[], TaskAppConfig]
|
|
35
|
+
aliases: Sequence[str] = field(default_factory=tuple)
|
|
36
|
+
env_files: Sequence[str] = field(default_factory=tuple)
|
|
37
|
+
modal: ModalDeploymentConfig | None = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class TaskAppRegistry:
|
|
41
|
+
"""In-memory registry of known task apps."""
|
|
42
|
+
|
|
43
|
+
def __init__(self) -> None:
|
|
44
|
+
self._entries: Dict[str, TaskAppEntry] = {}
|
|
45
|
+
self._alias_to_id: Dict[str, str] = {}
|
|
46
|
+
|
|
47
|
+
def register(self, entry: TaskAppEntry) -> None:
|
|
48
|
+
if entry.app_id in self._entries:
|
|
49
|
+
raise ValueError(f"Task app already registered: {entry.app_id}")
|
|
50
|
+
self._entries[entry.app_id] = entry
|
|
51
|
+
for alias in entry.aliases:
|
|
52
|
+
if alias in self._alias_to_id:
|
|
53
|
+
raise ValueError(f"Alias already registered: {alias}")
|
|
54
|
+
self._alias_to_id[alias] = entry.app_id
|
|
55
|
+
|
|
56
|
+
def get(self, app_id: str) -> TaskAppEntry:
|
|
57
|
+
resolved = self._alias_to_id.get(app_id, app_id)
|
|
58
|
+
if resolved not in self._entries:
|
|
59
|
+
raise KeyError(f"Unknown task app id: {app_id}")
|
|
60
|
+
return self._entries[resolved]
|
|
61
|
+
|
|
62
|
+
def list(self) -> List[TaskAppEntry]:
|
|
63
|
+
return sorted(self._entries.values(), key=lambda entry: entry.app_id)
|
|
64
|
+
|
|
65
|
+
def __iter__(self) -> Iterable[TaskAppEntry]:
|
|
66
|
+
return iter(self.list())
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
registry = TaskAppRegistry()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def register_task_app(*, entry: TaskAppEntry) -> None:
|
|
73
|
+
registry.register(entry)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Register built-in task apps
|
|
78
|
+
try:
|
|
79
|
+
from . import grpo_crafter # noqa: F401
|
|
80
|
+
except Exception:
|
|
81
|
+
# Defer import errors so CLI can report missing deps gracefully
|
|
82
|
+
pass
|
|
83
|
+
|
|
84
|
+
try:
|
|
85
|
+
from . import math_single_step # noqa: F401
|
|
86
|
+
except Exception:
|
|
87
|
+
# Defer import errors so CLI can report missing deps gracefully
|
|
88
|
+
pass
|