synth-ai 0.2.8.dev12__py3-none-any.whl → 0.2.9.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/api/train/__init__.py +5 -0
- synth_ai/api/train/builders.py +165 -0
- synth_ai/api/train/cli.py +450 -0
- synth_ai/api/train/config_finder.py +168 -0
- synth_ai/api/train/env_resolver.py +302 -0
- synth_ai/api/train/pollers.py +66 -0
- synth_ai/api/train/task_app.py +193 -0
- synth_ai/api/train/utils.py +232 -0
- synth_ai/cli/__init__.py +23 -0
- synth_ai/cli/rl_demo.py +18 -6
- synth_ai/cli/root.py +38 -6
- synth_ai/cli/task_apps.py +1107 -0
- synth_ai/demo_registry.py +258 -0
- synth_ai/demos/core/cli.py +147 -111
- synth_ai/demos/demo_task_apps/__init__.py +7 -1
- synth_ai/demos/demo_task_apps/math/config.toml +55 -110
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +157 -21
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +39 -0
- synth_ai/task/__init__.py +94 -1
- synth_ai/task/apps/__init__.py +88 -0
- synth_ai/task/apps/grpo_crafter.py +438 -0
- synth_ai/task/apps/math_single_step.py +852 -0
- synth_ai/task/auth.py +153 -0
- synth_ai/task/client.py +165 -0
- synth_ai/task/contracts.py +29 -14
- synth_ai/task/datasets.py +105 -0
- synth_ai/task/errors.py +49 -0
- synth_ai/task/json.py +77 -0
- synth_ai/task/proxy.py +258 -0
- synth_ai/task/rubrics.py +212 -0
- synth_ai/task/server.py +398 -0
- synth_ai/task/tracing_utils.py +79 -0
- synth_ai/task/vendors.py +61 -0
- synth_ai/tracing_v3/session_tracer.py +13 -5
- synth_ai/tracing_v3/storage/base.py +10 -12
- synth_ai/tracing_v3/turso/manager.py +20 -6
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/METADATA +3 -2
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/RECORD +42 -18
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/top_level.txt +0 -0
synth_ai/task/server.py
ADDED
@@ -0,0 +1,398 @@
+from __future__ import annotations
+
+"""FastAPI scaffolding for Task Apps (local dev + deployment)."""
+
+import asyncio
+import inspect
+import os
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Awaitable, Callable, Iterable, Mapping, MutableMapping, Sequence
+
+import httpx
+from fastapi import APIRouter, Depends, FastAPI, Query, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from starlette.middleware import Middleware
+
+from .auth import (
+    is_api_key_header_authorized,
+    normalize_environment_api_key,
+    require_api_key_dependency,
+)
+from .contracts import RolloutRequest, RolloutResponse, TaskInfo
+from .datasets import TaskDatasetRegistry
+from .errors import http_exception
+from .json import to_jsonable
+from .proxy import (
+    prepare_for_groq,
+    prepare_for_openai,
+    inject_system_hint,
+    synthesize_tool_call_if_missing,
+)
+from .rubrics import Rubric
+from .vendors import get_groq_key_or_503, get_openai_key_or_503, normalize_vendor_keys
+
+
+TasksetDescriptor = Callable[[], Mapping[str, Any] | Awaitable[Mapping[str, Any]]]
+InstanceProvider = Callable[[Sequence[int]], Iterable[TaskInfo] | Awaitable[Iterable[TaskInfo]]]
+RolloutExecutor = Callable[[RolloutRequest, Request], Any | Awaitable[Any]]
+
+
+@dataclass(slots=True)
+class RubricBundle:
+    """Optional rubrics advertised by the task app."""
+
+    outcome: Rubric | None = None
+    events: Rubric | None = None
+
+
+@dataclass(slots=True)
+class ProxyConfig:
+    """Configuration for optional vendor proxy endpoints."""
+
+    enable_openai: bool = False
+    enable_groq: bool = False
+    system_hint: str | None = None
+    openai_url: str = "https://api.openai.com/v1/chat/completions"
+    groq_url: str = "https://api.groq.com/openai/v1/chat/completions"
+
+
+@dataclass(slots=True)
+class TaskAppConfig:
+    """Declarative configuration describing a Task App."""
+
+    app_id: str
+    name: str
+    description: str
+    base_task_info: TaskInfo
+    describe_taskset: TasksetDescriptor
+    provide_task_instances: InstanceProvider
+    rollout: RolloutExecutor
+    dataset_registry: TaskDatasetRegistry | None = None
+    rubrics: RubricBundle = field(default_factory=RubricBundle)
+    proxy: ProxyConfig | None = None
+    routers: Sequence[APIRouter] = field(default_factory=tuple)
+    middleware: Sequence[Middleware] = field(default_factory=tuple)
+    app_state: Mapping[str, Any] = field(default_factory=dict)
+    require_api_key: bool = True
+    expose_debug_env: bool = True
+    cors_origins: Sequence[str] | None = None
+    startup_hooks: Sequence[Callable[[], None | Awaitable[None]]] = field(default_factory=tuple)
+    shutdown_hooks: Sequence[Callable[[], None | Awaitable[None]]] = field(default_factory=tuple)
+
+    def clone(self) -> "TaskAppConfig":
+        """Return a shallow copy safe to mutate when wiring the app."""
+
+        return TaskAppConfig(
+            app_id=self.app_id,
+            name=self.name,
+            description=self.description,
+            base_task_info=self.base_task_info,
+            describe_taskset=self.describe_taskset,
+            provide_task_instances=self.provide_task_instances,
+            rollout=self.rollout,
+            dataset_registry=self.dataset_registry,
+            rubrics=self.rubrics,
+            proxy=self.proxy,
+            routers=tuple(self.routers),
+            middleware=tuple(self.middleware),
+            app_state=dict(self.app_state),
+            require_api_key=self.require_api_key,
+            expose_debug_env=self.expose_debug_env,
+            cors_origins=tuple(self.cors_origins or ()),
+            startup_hooks=tuple(self.startup_hooks),
+            shutdown_hooks=tuple(self.shutdown_hooks),
+        )
+
+
+def _maybe_await(result: Any) -> Awaitable[Any]:
+    if inspect.isawaitable(result):
+        return asyncio.ensure_future(result)
+    loop = asyncio.get_event_loop()
+    future: asyncio.Future[Any] = loop.create_future()
+    future.set_result(result)
+    return future
+
+
+def _ensure_task_info(obj: Any) -> TaskInfo:
+    if isinstance(obj, TaskInfo):
+        return obj
+    if isinstance(obj, MutableMapping):
+        return TaskInfo.model_validate(obj)
+    raise TypeError(f"Task instance provider must yield TaskInfo-compatible objects (got {type(obj)!r})")
+
+
+def _normalise_seeds(values: Sequence[int]) -> list[int]:
+    seeds: list[int] = []
+    for value in values:
+        try:
+            seeds.append(int(value))
+        except Exception as exc:  # pragma: no cover - defensive
+            raise ValueError(f"Seed values must be convertible to int (got {value!r})") from exc
+    return seeds
+
+
+def _build_proxy_routes(
+    app: FastAPI, config: TaskAppConfig, auth_dependency: Callable[[Request], None]
+) -> None:
+    proxy = config.proxy
+    if not proxy:
+        return
+
+    async def _call_vendor(url: str, payload: dict[str, Any], headers: dict[str, str]) -> dict[str, Any]:
+        async with httpx.AsyncClient(timeout=httpx.Timeout(600.0), follow_redirects=True) as client:
+            response = await client.post(url, json=payload, headers=headers)
+            data = (
+                response.json()
+                if response.headers.get("content-type", "").startswith("application/json")
+                else {"raw": response.text}
+            )
+            if response.status_code >= 400:
+                code = "vendor_error"
+                if url.startswith("https://api.openai.com"):
+                    code = "openai_error"
+                elif "groq" in url:
+                    code = "groq_error"
+                raise http_exception(
+                    response.status_code,
+                    code,
+                    "Vendor proxy error",
+                    extra={"status": response.status_code, "body": data},
+                )
+            return data
+
+    def _log_proxy(route: str, payload: dict[str, Any]) -> None:
+        try:
+            messages = payload.get("messages") if isinstance(payload, dict) else None
+            msg_count = len(messages) if isinstance(messages, list) else 0
+            tool_count = len(payload.get("tools") or []) if isinstance(payload, dict) else 0
+            model = payload.get("model") if isinstance(payload, dict) else None
+            print(f"[task:proxy:{route}] model={model} messages={msg_count} tools={tool_count}", flush=True)
+        except Exception:  # pragma: no cover - best effort logging
+            pass
+
+    system_hint = proxy.system_hint
+
+    if proxy.enable_openai:
+        @app.post("/proxy/v1/chat/completions", dependencies=[Depends(auth_dependency)])
+        async def proxy_openai(body: dict[str, Any], request: Request) -> Any:  # type: ignore[no-redef]
+            key = get_openai_key_or_503()
+            model = body.get("model") if isinstance(body.get("model"), str) else None
+            payload = prepare_for_openai(model, body)
+            payload = inject_system_hint(payload, system_hint or "")
+            _log_proxy("openai", payload)
+            data = await _call_vendor(proxy.openai_url, payload, {"Authorization": f"Bearer {key}"})
+            sanitized = synthesize_tool_call_if_missing(data)
+            return to_jsonable(sanitized)
+
+    if proxy.enable_groq:
+        @app.post("/proxy/groq/v1/chat/completions", dependencies=[Depends(auth_dependency)])
+        async def proxy_groq(body: dict[str, Any], request: Request) -> Any:  # type: ignore[no-redef]
+            key = get_groq_key_or_503()
+            model = body.get("model") if isinstance(body.get("model"), str) else None
+            payload = prepare_for_groq(model, body)
+            payload = inject_system_hint(payload, system_hint or "")
+            _log_proxy("groq", payload)
+            data = await _call_vendor(proxy.groq_url.rstrip("/"), payload, {"Authorization": f"Bearer {key}"})
+            sanitized = synthesize_tool_call_if_missing(data)
+            return to_jsonable(sanitized)
+
+
+def _auth_dependency_factory(config: TaskAppConfig) -> Callable[[Request], None]:
+    def _dependency(request: Request) -> None:
+        if not config.require_api_key:
+            return
+        require_api_key_dependency(request)
+
+    return _dependency
+
+
+def create_task_app(config: TaskAppConfig) -> FastAPI:
+    cfg = config.clone()
+    app = FastAPI(title=cfg.name, description=cfg.description)
+
+    for key, value in cfg.app_state.items():
+        setattr(app.state, key, value)
+
+    if cfg.cors_origins is not None:
+        app.add_middleware(
+            CORSMiddleware,
+            allow_origins=list(cfg.cors_origins) or ["*"],
+            allow_credentials=True,
+            allow_methods=["*"],
+            allow_headers=["*"],
+        )
+
+    # Note: additional middleware from cfg.middleware is currently disabled to avoid typing ambiguity.
+    # for middleware in cfg.middleware:
+    #     try:
+    #         opts = getattr(middleware, "options", {})
+    #     except Exception:
+    #         opts = {}
+    #     app.add_middleware(middleware.cls, **(opts if isinstance(opts, dict) else {}))
+
+    for router in cfg.routers:
+        try:
+            app.include_router(router)
+        except Exception:
+            try:
+                inner = getattr(router, "router", None)
+                if inner is not None:
+                    app.include_router(inner)
+            except Exception:
+                raise
+
+    auth_dependency = _auth_dependency_factory(cfg)
+
+    def _call_hook(hook: Callable[..., Any]) -> Awaitable[Any]:
+        try:
+            params = inspect.signature(hook).parameters  # type: ignore[arg-type]
+        except (TypeError, ValueError):
+            params = {}
+        if params:
+            return _maybe_await(hook(app))  # type: ignore[misc]
+        return _maybe_await(hook())
+
+    @app.on_event("startup")
+    async def _startup() -> None:  # pragma: no cover - FastAPI lifecycle
+        normalize_environment_api_key()
+        normalize_vendor_keys()
+        for hook in cfg.startup_hooks:
+            await _call_hook(hook)
+
+    @app.on_event("shutdown")
+    async def _shutdown() -> None:  # pragma: no cover - FastAPI lifecycle
+        for hook in cfg.shutdown_hooks:
+            await _call_hook(hook)
+
+    @app.get("/")
+    async def root() -> Mapping[str, Any]:
+        return to_jsonable({"status": "ok", "service": cfg.app_id})
+
+    @app.head("/")
+    async def root_head() -> Mapping[str, Any]:
+        return to_jsonable({"status": "ok"})
+
+    @app.get("/health", dependencies=[Depends(auth_dependency)])
+    async def health(request: Request) -> Mapping[str, Any]:
+        # If we got here, auth_dependency already verified the key exactly matches
+        expected = normalize_environment_api_key()
+        return to_jsonable({"healthy": True, "auth": {"required": True, "expected_prefix": (expected[:6] + '...') if expected else '<unset>'}})

+    @app.get("/info", dependencies=[Depends(auth_dependency)])
+    async def info() -> Mapping[str, Any]:
+        dataset_meta = cfg.base_task_info.dataset
+        rubrics: dict[str, Any] | None = None
+        if cfg.rubrics.outcome or cfg.rubrics.events:
+            rubrics = {
+                "outcome": cfg.rubrics.outcome.model_dump() if cfg.rubrics.outcome else None,
+                "events": cfg.rubrics.events.model_dump() if cfg.rubrics.events else None,
+            }
+        payload = {
+            "service": {
+                "task": cfg.base_task_info.task,
+                "version": cfg.base_task_info.task.get("version"),
+            },
+            "dataset": dataset_meta,
+            "rubrics": rubrics,
+            "inference": cfg.base_task_info.inference,
+            "capabilities": cfg.base_task_info.capabilities,
+            "limits": cfg.base_task_info.limits,
+        }
+        return to_jsonable(payload)
+
+    @app.get("/task_info", dependencies=[Depends(auth_dependency)])
+    async def task_info(
+        seed: Sequence[int] | None = Query(default=None),
+        seeds: Sequence[int] | None = Query(default=None),
+    ) -> Any:
+        all_seeds: list[int] = []
+        if seed:
+            all_seeds.extend(_normalise_seeds(seed))
+        if seeds:
+            all_seeds.extend(_normalise_seeds(seeds))
+
+        if not all_seeds:
+            descriptor_result = await _maybe_await(cfg.describe_taskset())
+            return to_jsonable({"taskset": descriptor_result})
+
+        instances = await _maybe_await(cfg.provide_task_instances(all_seeds))
+        payload = [to_jsonable(_ensure_task_info(instance).model_dump()) for instance in instances]
+        return payload[0] if len(payload) == 1 else payload
+
+    @app.post("/rollout", dependencies=[Depends(auth_dependency)])
+    async def rollout_endpoint(rollout_request: RolloutRequest, request: Request) -> Any:
+        result = await _maybe_await(cfg.rollout(rollout_request, request))
+        if isinstance(result, RolloutResponse):
+            return to_jsonable(result.model_dump())
+        if isinstance(result, Mapping):
+            try:
+                validated = RolloutResponse.model_validate(result)
+            except Exception:
+                return to_jsonable(result)
+            return to_jsonable(validated.model_dump())
+        raise TypeError("Rollout executor must return RolloutResponse or mapping")
+
+    if cfg.expose_debug_env:
+        @app.get("/debug/env", dependencies=[Depends(auth_dependency)])
+        async def debug_env() -> Mapping[str, Any]:
+            def _mask(value: str | None) -> str:
+                if not value:
+                    return ""
+                return f"{value[:6]}…" if len(value) > 6 else value
+
+            return to_jsonable(
+                {
+                    "has_ENVIRONMENT_API_KEY": bool(os.getenv("ENVIRONMENT_API_KEY")),
+                    "OPENAI_API_KEY_prefix": _mask(os.getenv("OPENAI_API_KEY")),
+                    "GROQ_API_KEY_prefix": _mask(os.getenv("GROQ_API_KEY")),
+                }
+            )
+
+    _build_proxy_routes(app, cfg, auth_dependency)
+
+    return app
+
+
+def _load_env_files(env_files: Sequence[str]) -> list[str]:
+    loaded: list[str] = []
+    if not env_files:
+        return loaded
+    try:
+        import dotenv
+    except Exception:  # pragma: no cover - optional dep
+        return loaded
+    for path_str in env_files:
+        path = Path(path_str)
+        if not path.is_file():
+            continue
+        dotenv.load_dotenv(path, override=False)
+        loaded.append(str(path))
+    return loaded
+
+
+def run_task_app(
+    config_factory: Callable[[], TaskAppConfig],
+    *,
+    host: str = "0.0.0.0",
+    port: int = 8001,
+    reload: bool = False,
+    env_files: Sequence[str] = (),
+) -> None:
+    """Run the provided Task App configuration with uvicorn."""
+
+    loaded_files = _load_env_files(env_files)
+    if loaded_files:
+        print(f"[task:server] Loaded environment from: {', '.join(loaded_files)}", flush=True)
+
+    config = config_factory()
+    app = create_task_app(config)
+
+    try:
+        import uvicorn
+    except ImportError as exc:  # pragma: no cover - uvicorn optional
+        raise RuntimeError("uvicorn must be installed to run the task app locally") from exc
+
+    print(f"[task:server] Starting '{config.app_id}' on {host}:{port}", flush=True)
+    uvicorn.run(app, host=host, port=port, reload=reload)
synth_ai/task/tracing_utils.py
ADDED
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+"""Utilities for wiring tracing_v3 into task apps."""
+
+import os
+import time
+from pathlib import Path
+from typing import Any, Callable
+
+
+def tracing_env_enabled(default: bool = False) -> bool:
+    """Return True when tracing is enabled for task apps via environment variable."""
+
+    raw = os.getenv("TASKAPP_TRACING_ENABLED")
+    if raw is None:
+        return default
+    raw = raw.strip().lower()
+    if raw in {"1", "true", "t", "yes", "y", "on"}:
+        return True
+    if raw in {"0", "false", "f", "no", "n", "off"}:
+        return False
+    return default
+
+
+def resolve_tracing_db_url() -> str | None:
+    """Resolve tracing database URL and prefer async drivers for SQLite."""
+
+    db_url = os.getenv("TURSO_LOCAL_DB_URL")
+    if db_url:
+        return db_url
+
+    sqld_path = os.getenv("SQLD_DB_PATH")
+    if sqld_path:
+        path = Path(sqld_path).expanduser()
+        if path.is_dir():
+            candidate = path / "dbs" / "default" / "data"
+            candidate.parent.mkdir(parents=True, exist_ok=True)
+            return f"sqlite+aiosqlite:///{candidate}"
+        else:
+            path.parent.mkdir(parents=True, exist_ok=True)
+            return f"sqlite+aiosqlite:///{path}"
+
+    fallback_path = Path("traces/v3/synth_ai.db").expanduser()
+    fallback_path.parent.mkdir(parents=True, exist_ok=True)
+    return f"sqlite+aiosqlite:///{fallback_path}"
+
+
+def build_tracer_factory(make_tracer: Callable[..., Any], *, enabled: bool, db_url: str | None) -> Callable[[], Any] | None:
+    """Return a factory that instantiates a tracer when enabled, else None."""
+
+    if not enabled:
+        return None
+
+    def _factory() -> Any:
+        return make_tracer(db_url=db_url) if db_url else make_tracer()
+
+    return _factory
+
+
+def resolve_sft_output_dir() -> str | None:
+    """Resolve location for writing SFT records, creating directory if requested."""
+
+    raw = os.getenv("TASKAPP_SFT_OUTPUT_DIR") or os.getenv("SFT_OUTPUT_DIR")
+    if not raw:
+        return None
+    path = Path(raw).expanduser()
+    try:
+        path.mkdir(parents=True, exist_ok=True)
+    except Exception:
+        return None
+    return str(path)
+
+
+def unique_sft_path(base_dir: str, *, run_id: str) -> Path:
+    """Return a unique JSONL path for an SFT record batch."""
+
+    ts = int(time.time() * 1000)
+    name = f"{run_id}_{ts}.jsonl"
+    return Path(base_dir) / name
synth_ai/task/vendors.py
ADDED
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+"""Vendor API key helpers shared by Task Apps."""
+
+import os
+from typing import Optional
+
+from .errors import http_exception
+
+_VENDOR_KEYS = {
+    "OPENAI_API_KEY": ("dev_openai_api_key", "DEV_OPENAI_API_KEY"),
+    "GROQ_API_KEY": ("dev_groq_api_key", "DEV_GROQ_API_KEY"),
+}
+
+
+def _mask(value: str, *, prefix: int = 4) -> str:
+    if not value:
+        return "<empty>"
+    visible = value[:prefix]
+    return f"{visible}{'…' if len(value) > prefix else ''}"
+
+
+def _normalize_single(key: str) -> Optional[str]:
+    direct = os.getenv(key)
+    if direct:
+        return direct
+    fallbacks = _VENDOR_KEYS.get(key, ())
+    for env in fallbacks:
+        candidate = os.getenv(env)
+        if candidate:
+            os.environ[key] = candidate
+            print(
+                f"[task:vendor] {key} set from {env} (prefix={_mask(candidate)})",
+                flush=True,
+            )
+            return candidate
+    return None
+
+
+def normalize_vendor_keys() -> dict[str, Optional[str]]:
+    """Normalise known vendor keys from dev fallbacks and return the mapping."""
+
+    resolved: dict[str, Optional[str]] = {}
+    for key in _VENDOR_KEYS:
+        resolved[key] = _normalize_single(key)
+    return resolved
+
+
+def get_openai_key_or_503() -> str:
+    key = _normalize_single("OPENAI_API_KEY")
+    if not key:
+        raise http_exception(503, "missing_openai_api_key", "OPENAI_API_KEY is not configured")
+    return key
+
+
+def get_groq_key_or_503() -> str:
+    key = _normalize_single("GROQ_API_KEY")
+    if not key:
+        raise http_exception(503, "missing_groq_api_key", "GROQ_API_KEY is not configured")
+    return key
+
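The subtlety here is the fallback promotion: _normalize_single copies a dev-prefixed variable (e.g. DEV_OPENAI_API_KEY) into the canonical name and logs a masked prefix. A sketch of the observable behaviour, assuming no canonical keys are already set:

import os
from synth_ai.task.vendors import normalize_vendor_keys

# With only the dev fallback set, the canonical variable is populated on access.
os.environ.pop("OPENAI_API_KEY", None)
os.environ["DEV_OPENAI_API_KEY"] = "sk-test-1234"

resolved = normalize_vendor_keys()
# Logs: [task:vendor] OPENAI_API_KEY set from DEV_OPENAI_API_KEY (prefix=sk-t…)
assert os.environ["OPENAI_API_KEY"] == "sk-test-1234"
assert resolved["GROQ_API_KEY"] is None  # no Groq key or fallback present here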
synth_ai/tracing_v3/session_tracer.py
@@ -234,7 +234,7 @@ class SessionTracer:
         event_id = await self.db.insert_event_row(
             self._current_trace.session_id,
             timestep_db_id=timestep_db_id,
-            event=event,
+            event=event,  # type: ignore[arg-type]
         )
         # Auto-insert an event reward if EnvironmentEvent carries reward
         try:
@@ -323,7 +323,7 @@ class SessionTracer:
             return message_id
         return None
 
-    async def end_session(self, save: bool = None) -> SessionTrace:
+    async def end_session(self, save: bool | None = None) -> SessionTrace:
         """End the current session.
 
         Args:
@@ -370,7 +370,7 @@ class SessionTracer:
         self,
         session_id: str | None = None,
         metadata: dict[str, Any] | None = None,
-        save: bool = None,
+        save: bool | None = None,
     ):
         """Context manager for a session.
 
@@ -414,8 +414,16 @@ class SessionTracer:
         if limit:
             query += f" LIMIT {limit}"
 
-
-
+        # Ensure DB initialized before querying
+        if self.db is None:
+            await self.initialize()
+        df_or_records = await self.db.query_traces(query)  # type: ignore[union-attr]
+        try:
+            # If pandas DataFrame
+            return df_or_records.to_dict("records")  # type: ignore[call-arg, attr-defined]
+        except AttributeError:
+            # Already list of dicts
+            return df_or_records
 
     async def close(self):
         """Close database connections."""
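The new tail of that query method normalizes whatever the storage backend hands back: it lazily initializes the DB handle, then duck-types the result so callers always receive plain records. The same pattern in isolation, as a sketch:

# Duck-typing normalization mirroring the hunk above: accept either a pandas
# DataFrame or an already-materialized list of dict records.
def to_records(df_or_records):
    try:
        return df_or_records.to_dict("records")  # pandas DataFrame path
    except AttributeError:
        return df_or_records  # already list[dict]

print(to_records([{"session_id": "s1"}]))  # -> [{'session_id': 's1'}]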
synth_ai/tracing_v3/storage/base.py
@@ -4,8 +4,6 @@ from abc import ABC, abstractmethod
 from datetime import datetime
 from typing import Any
 
-import pandas as pd
-
 from ..abstractions import SessionTrace
 
 
@@ -42,22 +40,22 @@ class TraceStorage(ABC):
         pass
 
     @abstractmethod
-    async def query_traces(self, query: str, params: dict[str, Any] = None) ->
-        """Execute a query and return results
+    async def query_traces(self, query: str, params: dict[str, Any] | None = None) -> Any:
+        """Execute a query and return results.
 
         Args:
             query: The SQL query to execute
             params: Optional query parameters
 
         Returns:
-            Query results as a DataFrame
+            Query results as a DataFrame-like object or list of dict records
         """
         pass
 
     @abstractmethod
     async def get_model_usage(
-        self, start_date: datetime = None, end_date: datetime = None, model_name: str = None
-    ) ->
+        self, start_date: datetime | None = None, end_date: datetime | None = None, model_name: str | None = None
+    ) -> Any:
         """Get model usage statistics.
 
         Args:
@@ -66,7 +64,7 @@ class TraceStorage(ABC):
             model_name: Optional model name filter
 
         Returns:
-            Model usage statistics as a DataFrame
+            Model usage statistics as a DataFrame-like object or list of dict records
         """
         pass
 
@@ -92,8 +90,8 @@ class TraceStorage(ABC):
         self,
         experiment_id: str,
         name: str,
-        description: str = None,
-        configuration: dict[str, Any] = None,
+        description: str | None = None,
+        configuration: dict[str, Any] | None = None,
     ) -> str:
         """Create a new experiment."""
         raise NotImplementedError("Experiment management not supported by this backend")
@@ -103,14 +101,14 @@ class TraceStorage(ABC):
         raise NotImplementedError("Experiment management not supported by this backend")
 
     async def get_sessions_by_experiment(
-        self, experiment_id: str, limit: int = None
+        self, experiment_id: str, limit: int | None = None
     ) -> list[dict[str, Any]]:
         """Get all sessions for an experiment."""
         raise NotImplementedError("Experiment management not supported by this backend")
 
     # Batch operations
     async def batch_insert_sessions(
-        self, traces: list[SessionTrace], batch_size: int = 1000
+        self, traces: list[SessionTrace], batch_size: int | None = 1000
     ) -> list[str]:
         """Batch insert multiple session traces.
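The net effect of this hunk set: pandas is no longer a hard import for storage backends, and the abstract signatures use explicit "| None" optionals. A backend can now satisfy query_traces by returning plain records; a minimal sketch follows (the other abstract methods are elided, so this is not a complete TraceStorage subclass):

from typing import Any

class InMemoryTraceQueries:
    """Sketch of a pandas-free backend honouring the relaxed contract."""

    def __init__(self) -> None:
        self._rows: list[dict[str, Any]] = []

    async def query_traces(self, query: str, params: dict[str, Any] | None = None) -> Any:
        # "DataFrame-like object or list of dict records" -- we return the latter.
        return list(self._rows)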