synth-ai 0.2.8.dev13__py3-none-any.whl → 0.2.9.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic.
- synth_ai/api/train/cli.py +21 -0
- synth_ai/api/train/config_finder.py +54 -6
- synth_ai/api/train/task_app.py +70 -5
- synth_ai/cli/rl_demo.py +16 -4
- synth_ai/cli/root.py +36 -5
- synth_ai/cli/task_apps.py +792 -205
- synth_ai/demo_registry.py +258 -0
- synth_ai/demos/core/cli.py +147 -111
- synth_ai/demos/demo_task_apps/__init__.py +7 -1
- synth_ai/demos/demo_task_apps/math/config.toml +55 -110
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +157 -21
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +39 -0
- synth_ai/task/auth.py +33 -12
- synth_ai/task/client.py +20 -3
- {synth_ai-0.2.8.dev13.dist-info → synth_ai-0.2.9.dev1.dist-info}/METADATA +1 -1
- {synth_ai-0.2.8.dev13.dist-info → synth_ai-0.2.9.dev1.dist-info}/RECORD +20 -18
- {synth_ai-0.2.8.dev13.dist-info → synth_ai-0.2.9.dev1.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.8.dev13.dist-info → synth_ai-0.2.9.dev1.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.8.dev13.dist-info → synth_ai-0.2.9.dev1.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.8.dev13.dist-info → synth_ai-0.2.9.dev1.dist-info}/top_level.txt +0 -0
synth_ai/demos/demo_task_apps/math/config.toml
CHANGED

@@ -1,129 +1,74 @@
-[
-
-
-
-trainer_mode = "full"
-
-[lora]
-r = 16
-alpha = 32
-dropout = 0.05
-target_modules = [
-    "q_proj", "k_proj", "v_proj", "o_proj",
-    "gate_proj", "up_proj", "down_proj",
-]
-
-[rdma]
-enabled = false
-ifname = "eth0"
-ip_type = "ipv4"
-p2p_disable = 0
-shm_disable = 0
-fast_nccl = false
-
-gid_index = 3
-cross_nic = 0
-collnet_enable = 0
-net_gdr_level = 2
-
-nsocks_perthread = 4
-socket_nthreads = 2
+[algorithm]
+type = "online"
+method = "policy_gradient"
+variety = "gspo"
 
-
-
-p2p_level = "SYS"
-debug = "INFO"
+[services]
+task_url = "http://localhost:8101"
 
-[
-
-gpu_index = 1
-port = 8002
-tp = 1
-health_max_wait_s = 180
-health_interval_ms = 300
+[model]
+base = "Qwen/Qwen3-1.7B"
 
-[
-
-
-
-
-
-
-
+[policy]
+model = "Qwen/Qwen3-1.7B"
+inference_url = "http://localhost:8000/api/inference"
+max_tokens = 1028
+temperature = 0.2
+
+[data]
+split = "train"
+seed_start = 0
+episodes_per_iteration = 1280 # 8 per group * 4 groups per batch * 2 batches per step * 20 steps
+evaluation_split = "validation"
+evaluation_episodes = 50
 
 [training]
-
-
-batch_size =
-group_size =
+max_turns = 1
+ops = ["agent", "env"]
+batch_size = 2
+group_size = 16
+reward_positive = 1.0
+reward_negative_no_tool = -1.0
+reward_negative_no_answer = -0.5
 learning_rate = 5e-6
-max_grad_norm = 0.5
 log_interval = 1
-update_reference_interval = 0
 weight_sync_interval = 1
 
 [training.weight_sync]
 enable = true
 targets = ["policy"]
 
-[
-
-
-max_model_len = 8192
-max_num_seqs = 32
-enforce_eager = false
-max_parallel_generations = 4
+[compute]
+gpu_type = "H100"
+gpu_count = 4
 
-[
-
-
-
-
-
-every_n_iters = 5
-
-[rollout]
-env_name = "math"
-policy_name = "math-react"
-env_config = {}
-max_steps_per_episode = 5
-sampling_temperature = 0.3
-sampling_top_p = 0.95
-max_tokens = 1024
-max_concurrent_rollouts = 4
-ops_per_rollout = 14
-on_done = "reset"
-thinking_mode = "think"
-thinking_budget = 512
+[topology]
+type = "single_node_split"
+gpus_for_vllm = 2
+gpus_for_training = 1
+gpus_for_ref = 1
+tensor_parallel = 1
 
-[
-
+[vllm]
+tensor_parallel_size = 1
+max_model_len = 4096
 
-[
-
-
-
-
-
-advantage_normalization = true
-group_normalization = true
-num_inner_steps = 1
-clip_epsilon = 0.2
-completion_only = false
+[reference]
+placement = "dedicated"
+port = 8002
+tp = 1
+health_max_wait_s = 180
+health_interval_ms = 300
 
-[
-
-
-
-indicator_lambda = 0.0
+[rollout]
+policy_name = "math-single-step"
+max_turns = 1
+episodes_per_batch = 32 # group_size * batch_size
 
-[
-
+[evaluation]
+instances = 32
+every_n_iters = 10
+seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
 
-[
-
-directory = "/checkpoints"
-keep_last_n = 3
-save_optimizer = true
-save_scheduler = true
-enabled = true
+[tags]
+experiment = "math_single_step_qwen17"
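The rewritten demo config drops the flat trainer/rdma/checkpoint sections in favour of small [algorithm], [compute], [topology], [vllm], [reference], [rollout], and [evaluation] tables. The inline comments imply two invariants worth checking when editing it: the topology GPU split should cover compute.gpu_count, and rollout.episodes_per_batch should equal training.group_size * training.batch_size. A minimal sketch of such a check follows (illustrative only, not shipped with synth-ai; the helper name is made up):

import tomllib
from pathlib import Path

def check_demo_config(path: str = "synth_ai/demos/demo_task_apps/math/config.toml") -> None:
    cfg = tomllib.loads(Path(path).read_text())
    topo, training = cfg["topology"], cfg["training"]
    # 2 + 1 + 1 GPUs for vLLM / training / reference must cover compute.gpu_count = 4.
    gpu_split = topo["gpus_for_vllm"] + topo["gpus_for_training"] + topo["gpus_for_ref"]
    assert gpu_split == cfg["compute"]["gpu_count"], "topology split must cover every GPU"
    # episodes_per_batch = group_size * batch_size (16 * 2 = 32 in this config).
    assert cfg["rollout"]["episodes_per_batch"] == training["group_size"] * training["batch_size"]
    # vLLM tensor parallelism cannot exceed the GPUs reserved for inference.
    assert cfg["vllm"]["tensor_parallel_size"] <= topo["gpus_for_vllm"]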
synth_ai/demos/demo_task_apps/math/modal_task_app.py
CHANGED

@@ -7,6 +7,9 @@ from pathlib import Path
 
 from modal import App, Image, Secret, asgi_app
 from functools import lru_cache
+from typing import Iterable
+
+from starlette.requests import Request
 
 try:  # Backward compatibility with older installed SDKs
     from synth_ai.demos.demo_task_apps.core import DEFAULT_TASK_APP_SECRET_NAME

@@ -95,18 +98,77 @@ app = App("hendrycks-math-task-app")
 def fastapi_app():
     import httpx
     from fastapi import Body, HTTPException, status
-    from fastapi import FastAPI
+    from fastapi import FastAPI
     from fastapi.middleware.cors import CORSMiddleware
     from fastapi.responses import JSONResponse
+    try:
+        from synth_ai.task.auth import (
+            is_api_key_header_authorized,
+            normalize_environment_api_key,
+        )
+    except Exception:  # pragma: no cover - fallback for older synth-ai builds
+        def _normalize_env_key_fallback() -> str | None:
+            key = os.getenv("ENVIRONMENT_API_KEY")
+            if key:
+                return key
+            for alias in ("dev_environment_api_key", "DEV_ENVIRONMENT_API_KEY"):
+                candidate = os.getenv(alias)
+                if candidate:
+                    os.environ["ENVIRONMENT_API_KEY"] = candidate
+                    return candidate
+            return None
+
+        def normalize_environment_api_key() -> str | None:  # type: ignore[override]
+            return _normalize_env_key_fallback()
+
+        def _header_values(request: Request, header: str) -> Iterable[str]:
+            raw = request.headers.get(header) or request.headers.get(header.lower())
+            return [raw] if raw is not None else []
+
+        def _split(values: Iterable[str]) -> list[str]:
+            parts: list[str] = []
+            for value in values:
+                if not isinstance(value, str):
+                    continue
+                for chunk in value.split(','):
+                    chunk = chunk.strip()
+                    if chunk:
+                        parts.append(chunk)
+            return parts
+
+        def is_api_key_header_authorized(request: Request) -> bool:  # type: ignore[override]
+            expected = normalize_environment_api_key()
+            if not expected:
+                return False
+            single = _header_values(request, "x-api-key")
+            multi = _header_values(request, "x-api-keys")
+            auth = _header_values(request, "authorization")
+            bearer = []
+            for token in auth:
+                if isinstance(token, str) and token.lower().startswith("bearer "):
+                    bearer.append(token.split(" ", 1)[1].strip())
+            candidates = _split(single + multi + bearer)
+            return any(candidate == expected for candidate in candidates)
 
     # Inline, self-contained FastAPI app (math-only)
     @lru_cache(maxsize=1)
     def _hf_split(subject: str, split: str, slice_spec: str | None = None):
         from datasets import load_dataset  # type: ignore
+
         s = split
         if slice_spec:
             s = f"{s}{slice_spec}"
-
+
+        try:
+            return load_dataset("nlile/hendrycks-MATH-benchmark", subject, split=s)
+        except ValueError:
+            base = load_dataset("nlile/hendrycks-MATH-benchmark", split=s)
+            if subject and subject not in {"", "default"}:
+                if "subject" in base.column_names:
+                    base = base.filter(lambda ex: ex.get("subject") == subject)
+                elif isinstance(base, list):
+                    base = [ex for ex in base if ex.get("subject") == subject]
+            return base
 
     def _normalize_answer_text(s: str) -> str:
         import re as _re
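Whether the helpers come from synth_ai.task.auth or the inline fallback above, the check reads x-api-key, a comma-separated x-api-keys, and Authorization: Bearer, and compares each candidate against ENVIRONMENT_API_KEY. A rough usage sketch with a hypothetical stub request (real traffic arrives as a Starlette Request; is_api_key_header_authorized refers to whichever implementation is in scope):

import os

class _StubRequest:
    """Hypothetical stand-in exposing only the .headers mapping the check reads."""
    def __init__(self, headers: dict[str, str]) -> None:
        self.headers = headers

os.environ["ENVIRONMENT_API_KEY"] = "sk-env-123"
for headers in (
    {"x-api-key": "sk-env-123"},              # plain header match
    {"x-api-keys": "other,sk-env-123"},       # any CSV entry may match
    {"authorization": "Bearer sk-env-123"},   # bearer token is unwrapped first
    {"x-api-key": "wrong"},                   # no match -> False
):
    print(headers, is_api_key_header_authorized(_StubRequest(headers)))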
@@ -121,6 +183,9 @@ def fastapi_app():
         subj = subject or os.getenv("HENDRYCKS_MATH_CONFIG", "default")
         ds = _hf_split(subj, os.getenv("HENDRYCKS_MATH_SPLIT", "test"), os.getenv("HENDRYCKS_MATH_SLICE"))
         n = len(ds) if hasattr(ds, "__len__") else 0
+        if n == 0 and subject not in {"", "default"}:
+            ds = _hf_split("default", os.getenv("HENDRYCKS_MATH_SPLIT", "test"), os.getenv("HENDRYCKS_MATH_SLICE"))
+            n = len(ds) if hasattr(ds, "__len__") else 0
         if n == 0:
             raise RuntimeError("Hendrycks MATH dataset loaded empty")
         idx = abs(int(seed)) % n
@@ -158,6 +223,53 @@ def fastapi_app():
         logger.info(msg)
         return prefix
 
+    def _resolve_env_keys() -> set[str]:
+        keys: set[str] = set()
+        for alias in ("ENVIRONMENT_API_KEY", "dev_environment_api_key", "DEV_ENVIRONMENT_API_KEY"):
+            value = os.environ.get(alias)
+            if value:
+                os.environ.setdefault("ENVIRONMENT_API_KEY", value)
+                keys.add(value)
+        alias_env = os.environ.get("ENVIRONMENT_API_KEY_ALIASES", "")
+        for chunk in alias_env.split(","):
+            trimmed = chunk.strip()
+            if trimmed:
+                keys.add(trimmed)
+        return keys
+
+    def _extract_header_candidates(
+        request: Request,
+        x_api_key: str | None,
+        x_api_keys: str | None,
+        authorization: str | None,
+    ) -> list[str]:
+        headers = request.headers
+        candidates: list[str] = []
+        primary = x_api_key or headers.get("x-api-key")
+        if primary:
+            candidates.append(primary.strip())
+        secondary = x_api_keys or headers.get("x-api-keys")
+        if secondary:
+            candidates.extend([value.strip() for value in secondary.split(",") if value.strip()])
+        auth_header = authorization or headers.get("authorization") or headers.get("Authorization")
+        if auth_header and auth_header.lower().startswith("bearer "):
+            token = auth_header.split(" ", 1)[1].strip()
+            if token:
+                candidates.append(token)
+        return [c for c in candidates if c]
+
+    def _is_authorized(
+        request: Request,
+        x_api_key: str | None,
+        x_api_keys: str | None,
+        authorization: str | None,
+    ) -> bool:
+        keys = _resolve_env_keys()
+        if not keys:
+            return False
+        candidates = _extract_header_candidates(request, x_api_key, x_api_keys, authorization)
+        return any(candidate in keys for candidate in candidates)
+
     @app.get("/info")
     async def info():
         return {
@@ -166,42 +278,47 @@ def fastapi_app():
         }
 
     @app.get("/health")
-    async def health(
-
+    async def health(request: Request):
+        env_keys = _resolve_env_keys()
+        env_key = next(iter(env_keys), None)
         if not env_key:
             return JSONResponse(status_code=503, content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"})
-
+        # Authorize using all header variants; avoid typed Header params to prevent 422s
+        authorized = is_api_key_header_authorized(request)
+        if not authorized:
             prefix = _log_env_key_prefix("health", env_key)
-        content = {
-
+            content = {
+                "status": "healthy",
+                "authorized": False,
+            }
             if prefix:
-            content["detail"] = f"Invalid API key (expected prefix: {prefix})"
                 content["expected_api_key_prefix"] = prefix
-
-
-        return {"status": "healthy"}
+            return JSONResponse(status_code=200, content=content)
+        return {"status": "healthy", "authorized": True}
 
     # Optional rollout-specific health for CLI compatibility
     @app.get("/health/rollout")
-    async def health_rollout(
-
+    async def health_rollout(request: Request):
+        env_keys = _resolve_env_keys()
+        env_key = next(iter(env_keys), None)
         if not env_key:
             return JSONResponse(status_code=503, content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"})
-
+        authorized = is_api_key_header_authorized(request)
+        if not authorized:
             prefix = _log_env_key_prefix("health/rollout", env_key)
-        content = {
-
+            content = {
+                "status": "healthy",
+                "authorized": False,
+            }
             if prefix:
-            content["detail"] = f"Invalid or missing API key (expected prefix: {prefix})"
                 content["expected_api_key_prefix"] = prefix
-
-
-        return {"ok": True}
+            return JSONResponse(status_code=200, content=content)
+        return {"ok": True, "authorized": True}
 
     # _load_hendrycks_problem is defined at fastapi_app scope
 
     @app.get("/task_info")
-    async def task_info(seed: int = 0, subject: str = "
+    async def task_info(seed: int = 0, subject: str = "default"):
         """Return Hendrycks MATH problem/answer and tool schema for a seed."""
         q, a = _load_hendrycks_problem(int(seed), subject=subject)
         tools = [{
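Both health endpoints now accept the raw request, always answer 200 once ENVIRONMENT_API_KEY is set, and report authorization in the body rather than failing with 422 when typed headers are absent. A hedged probe of a deployed app might look like this (base URL and key are placeholders):

import httpx

BASE_URL = "https://example--hendrycks-math-task-app.modal.run"  # placeholder deployment URL
headers = {"X-API-Key": "sk-env-123"}                            # placeholder key

with httpx.Client(base_url=BASE_URL, timeout=30.0) as client:
    health = client.get("/health", headers=headers).json()
    rollout = client.get("/health/rollout", headers=headers).json()

# With a valid key: {"status": "healthy", "authorized": True} and {"ok": True, "authorized": True}.
# With a bad key the endpoints still return 200, but "authorized" is False and
# "expected_api_key_prefix" hints at the configured key.
print(health, rollout)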
@@ -229,6 +346,25 @@ def fastapi_app():
 
     api = create_app()
 
+    # Always log and surface 422 validation errors with header presence snapshot
+    from fastapi.exceptions import RequestValidationError
+
+    @api.exception_handler(RequestValidationError)
+    async def _on_validation_error(request: Request, exc: RequestValidationError):
+        try:
+            hdr = request.headers
+            snapshot = {
+                "path": str(getattr(request, "url").path),
+                "have_x_api_key": bool(hdr.get("x-api-key")),
+                "have_x_api_keys": bool(hdr.get("x-api-keys")),
+                "have_authorization": bool(hdr.get("authorization")),
+                "errors": exc.errors()[:5],
+            }
+            print("[422] validation", snapshot, flush=True)
+        except Exception:
+            pass
+        return JSONResponse(status_code=422, content={"status": "invalid", "detail": exc.errors()[:5]})
+
     @api.get("/")
     async def root_probe():
         return {"status": "ok", "service": "math"}
synth_ai/demos/demo_task_apps/math/task_app_entry.py
ADDED

@@ -0,0 +1,39 @@
+"""Task app registry entry for the math demo Modal deployment."""
+
+from __future__ import annotations
+
+from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
+from synth_ai.task.apps.math_single_step import build_config as base_build_config
+
+
+DEMO_MODAL_CONFIG = ModalDeploymentConfig(
+    app_name="hendrycks-math-task-app",
+    pip_packages=(
+        "fastapi>=0.110.0",
+        "uvicorn>=0.23.0",
+        "pydantic>=2.6.0",
+        "httpx>=0.24.0",
+        "numpy>=1.24.0",
+        "aiohttp>=3.8.0",
+        "datasets>=2.16.0",
+        "synth-ai",
+    ),
+)
+
+
+def build_config():
+    """Reuse the shared math single-step TaskAppConfig."""
+
+    return base_build_config()
+
+
+register_task_app(
+    entry=TaskAppEntry(
+        app_id="hendrycks-math-demo",
+        description="Demo math task app (Modal-focused) shipping with synth-ai demos.",
+        config_factory=build_config,
+        env_files=("examples/rl/.env",),
+        modal=DEMO_MODAL_CONFIG,
+    )
+)
+
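Registration happens as an import side effect: loading this module calls register_task_app with the hendrycks-math-demo entry and its Modal deployment settings. A small sketch of consuming it directly (assuming the deployment config exposes its constructor fields as attributes):

# Importing the entry module performs the register_task_app(...) side effect.
import synth_ai.demos.demo_task_apps.math.task_app_entry as entry

config = entry.build_config()              # same TaskAppConfig as synth_ai.task.apps.math_single_step
print(entry.DEMO_MODAL_CONFIG.app_name)    # "hendrycks-math-task-app"
print(entry.DEMO_MODAL_CONFIG.pip_packages)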
synth_ai/task/auth.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 """Authentication helpers shared by Task Apps."""
 
 import os
-from typing import Iterable, Optional, Any
+from typing import Iterable, Optional, Any, Set
 
 from .errors import http_exception
 

@@ -12,6 +12,7 @@ _DEV_API_KEY_ENVS = ("dev_environment_api_key", "DEV_ENVIRONMENT_API_KEY")
 _API_KEY_HEADER = "x-api-key"
 _API_KEYS_HEADER = "x-api-keys"
 _AUTH_HEADER = "authorization"
+_API_KEY_ALIASES_ENV = "ENVIRONMENT_API_KEY_ALIASES"  # comma-separated list of additional valid keys
 
 
 def _mask(value: str, *, prefix: int = 4) -> str:

@@ -42,6 +43,26 @@ def normalize_environment_api_key() -> Optional[str]:
     return None
 
 
+def allowed_environment_api_keys() -> Set[str]:
+    """Return the set of valid environment API keys for this Task App.
+
+    Includes:
+    - The primary ENVIRONMENT_API_KEY (normalized from dev fallbacks if needed)
+    - Any comma-separated aliases from ENVIRONMENT_API_KEY_ALIASES
+    """
+    keys: set[str] = set()
+    primary = normalize_environment_api_key()
+    if primary:
+        keys.add(primary)
+    aliases = (os.getenv(_API_KEY_ALIASES_ENV) or "").strip()
+    if aliases:
+        for part in aliases.split(","):
+            trimmed = part.strip()
+            if trimmed:
+                keys.add(trimmed)
+    return keys
+
+
 def _header_values(request: Any, header: str) -> Iterable[str]:
     header_lower = header.lower()
     if request is None:

@@ -78,10 +99,10 @@ def _split_csv(values: Iterable[str]) -> list[str]:
 
 
 def is_api_key_header_authorized(request: Any) -> bool:
-    """Return True if
+    """Return True if any header-provided key matches any allowed environment key."""
 
-
-    if not
+    allowed = allowed_environment_api_keys()
+    if not allowed:
         return False
     single = list(_header_values(request, _API_KEY_HEADER))
     multi = list(_header_values(request, _API_KEYS_HEADER))

@@ -91,14 +112,14 @@ def is_api_key_header_authorized(request: Any) -> bool:
         if isinstance(a, str) and a.lower().startswith("bearer "):
             bearer.append(a.split(" ", 1)[1].strip())
     candidates = _split_csv(single + multi + bearer)
-    return any(candidate
+    return any(candidate in allowed for candidate in candidates)
 
 
 def require_api_key_dependency(request: Any) -> None:
     """FastAPI dependency enforcing Task App authentication headers."""
 
-
-    if not
+    allowed = allowed_environment_api_keys()
+    if not allowed:
         raise http_exception(503, "missing_environment_api_key", "ENVIRONMENT_API_KEY is not configured")
     # Build candidate list for verbose diagnostics
     single = list(_header_values(request, _API_KEY_HEADER))

@@ -109,12 +130,12 @@ def require_api_key_dependency(request: Any) -> None:
         if isinstance(a, str) and a.lower().startswith("bearer "):
             bearer.append(a.split(" ", 1)[1].strip())
     candidates = _split_csv(single + multi + bearer)
-    if
+    if not any(candidate in allowed for candidate in candidates):
         try:
             print({
                 "task_auth_failed": True,
-                "
-                "
+                "allowed_first15": [k[:15] for k in allowed],
+                "allowed_count": len(allowed),
                 "got_first15": [c[:15] for c in candidates],
                 "got_lens": [len(c) for c in candidates],
                 "have_x_api_key": bool(single),

@@ -125,8 +146,8 @@ def require_api_key_dependency(request: Any) -> None:
             pass
         # Use 400 to make failures unmistakable during preflight
         raise http_exception(400, "unauthorised", "API key missing or invalid", extra={
-            "
-            "
+            "allowed_first15": [k[:15] for k in allowed],
+            "allowed_count": len(allowed),
             "got_first15": [c[:15] for c in candidates],
             "got_lens": [len(c) for c in candidates],
         })
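allowed_environment_api_keys() folds ENVIRONMENT_API_KEY_ALIASES into the accepted set, so a key rotation can keep the previous key valid while clients switch over. A rough sketch (the stub request is illustrative and assumes header lookup via request.headers.get):

import os
from synth_ai.task.auth import allowed_environment_api_keys, is_api_key_header_authorized

os.environ["ENVIRONMENT_API_KEY"] = "sk-new"
os.environ["ENVIRONMENT_API_KEY_ALIASES"] = "sk-old, sk-staging"

print(allowed_environment_api_keys())  # {"sk-new", "sk-old", "sk-staging"}

class _StubRequest:
    headers = {"x-api-keys": "sk-old"}  # the rotated-out key is still accepted

print(is_api_key_header_authorized(_StubRequest()))  # True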
synth_ai/task/client.py
CHANGED
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import asyncio
 from typing import Any, Dict, Iterable, List, Optional
+import os
 
 import httpx
 from pydantic import BaseModel

@@ -54,8 +55,24 @@ class TaskAppClient:
 
     def _headers(self) -> Dict[str, str]:
         headers: Dict[str, str] = {}
-
-
+        # Primary key
+        primary = (self.api_key or "").strip()
+        if primary:
+            headers["X-API-Key"] = primary
+            # Also set Authorization for clients that read bearer tokens
+            headers.setdefault("Authorization", f"Bearer {primary}")
+        # Include ALL available environment keys via CSV in X-API-Keys
+        keys: list[str] = []
+        if primary:
+            keys.append(primary)
+        aliases = (os.getenv("ENVIRONMENT_API_KEY_ALIASES") or "").strip()
+        if aliases:
+            for part in aliases.split(","):
+                trimmed = part.strip()
+                if trimmed and trimmed not in keys:
+                    keys.append(trimmed)
+        if keys:
+            headers["X-API-Keys"] = ",".join(keys)
         return headers
 
     async def aclose(self) -> None:

@@ -68,7 +85,7 @@ class TaskAppClient:
         method: str,
         path: str,
         *,
-        params: Optional[
+        params: Optional[Dict[str, Any] | List[tuple[str, Any]]] = None,
         json_payload: Any = None,
     ) -> httpx.Response:
         client = await self._ensure_client()
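On the client side, _headers() mirrors the alias handling: the primary key is sent as X-API-Key and as a Bearer token, and every known key (primary plus ENVIRONMENT_API_KEY_ALIASES) is joined into X-API-Keys. Roughly, a client whose api_key is "sk-new" with ENVIRONMENT_API_KEY_ALIASES="sk-old" in the environment would send (illustrative values):

expected_headers = {
    "X-API-Key": "sk-new",
    "Authorization": "Bearer sk-new",
    "X-API-Keys": "sk-new,sk-old",
}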
|