synth-ai 0.2.8.dev11__py3-none-any.whl → 0.2.8.dev13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (37)
  1. synth_ai/api/train/__init__.py +5 -0
  2. synth_ai/api/train/builders.py +165 -0
  3. synth_ai/api/train/cli.py +429 -0
  4. synth_ai/api/train/config_finder.py +120 -0
  5. synth_ai/api/train/env_resolver.py +302 -0
  6. synth_ai/api/train/pollers.py +66 -0
  7. synth_ai/api/train/task_app.py +128 -0
  8. synth_ai/api/train/utils.py +232 -0
  9. synth_ai/cli/__init__.py +23 -0
  10. synth_ai/cli/rl_demo.py +2 -2
  11. synth_ai/cli/root.py +2 -1
  12. synth_ai/cli/task_apps.py +520 -0
  13. synth_ai/demos/demo_task_apps/math/modal_task_app.py +31 -25
  14. synth_ai/task/__init__.py +94 -1
  15. synth_ai/task/apps/__init__.py +88 -0
  16. synth_ai/task/apps/grpo_crafter.py +438 -0
  17. synth_ai/task/apps/math_single_step.py +852 -0
  18. synth_ai/task/auth.py +132 -0
  19. synth_ai/task/client.py +148 -0
  20. synth_ai/task/contracts.py +29 -14
  21. synth_ai/task/datasets.py +105 -0
  22. synth_ai/task/errors.py +49 -0
  23. synth_ai/task/json.py +77 -0
  24. synth_ai/task/proxy.py +258 -0
  25. synth_ai/task/rubrics.py +212 -0
  26. synth_ai/task/server.py +398 -0
  27. synth_ai/task/tracing_utils.py +79 -0
  28. synth_ai/task/vendors.py +61 -0
  29. synth_ai/tracing_v3/session_tracer.py +13 -5
  30. synth_ai/tracing_v3/storage/base.py +10 -12
  31. synth_ai/tracing_v3/turso/manager.py +20 -6
  32. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/METADATA +3 -2
  33. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/RECORD +37 -15
  34. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/WHEEL +0 -0
  35. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/entry_points.txt +0 -0
  36. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/licenses/LICENSE +0 -0
  37. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/top_level.txt +0 -0
synth_ai/task/server.py
@@ -0,0 +1,398 @@
+ from __future__ import annotations
+
+ """FastAPI scaffolding for Task Apps (local dev + deployment)."""
+
+ import asyncio
+ import inspect
+ import os
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any, Awaitable, Callable, Iterable, Mapping, MutableMapping, Sequence
+
+ import httpx
+ from fastapi import APIRouter, Depends, FastAPI, Query, Request
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from starlette.middleware import Middleware
+
+ from .auth import (
+     is_api_key_header_authorized,
+     normalize_environment_api_key,
+     require_api_key_dependency,
+ )
+ from .contracts import RolloutRequest, RolloutResponse, TaskInfo
+ from .datasets import TaskDatasetRegistry
+ from .errors import http_exception
+ from .json import to_jsonable
+ from .proxy import (
+     prepare_for_groq,
+     prepare_for_openai,
+     inject_system_hint,
+     synthesize_tool_call_if_missing,
+ )
+ from .rubrics import Rubric
+ from .vendors import get_groq_key_or_503, get_openai_key_or_503, normalize_vendor_keys
+
+
+ TasksetDescriptor = Callable[[], Mapping[str, Any] | Awaitable[Mapping[str, Any]]]
+ InstanceProvider = Callable[[Sequence[int]], Iterable[TaskInfo] | Awaitable[Iterable[TaskInfo]]]
+ RolloutExecutor = Callable[[RolloutRequest, Request], Any | Awaitable[Any]]
+
+
+ @dataclass(slots=True)
+ class RubricBundle:
+     """Optional rubrics advertised by the task app."""
+
+     outcome: Rubric | None = None
+     events: Rubric | None = None
+
+
+ @dataclass(slots=True)
+ class ProxyConfig:
+     """Configuration for optional vendor proxy endpoints."""
+
+     enable_openai: bool = False
+     enable_groq: bool = False
+     system_hint: str | None = None
+     openai_url: str = "https://api.openai.com/v1/chat/completions"
+     groq_url: str = "https://api.groq.com/openai/v1/chat/completions"
+
+
+ @dataclass(slots=True)
+ class TaskAppConfig:
+     """Declarative configuration describing a Task App."""
+
+     app_id: str
+     name: str
+     description: str
+     base_task_info: TaskInfo
+     describe_taskset: TasksetDescriptor
+     provide_task_instances: InstanceProvider
+     rollout: RolloutExecutor
+     dataset_registry: TaskDatasetRegistry | None = None
+     rubrics: RubricBundle = field(default_factory=RubricBundle)
+     proxy: ProxyConfig | None = None
+     routers: Sequence[APIRouter] = field(default_factory=tuple)
+     middleware: Sequence[Middleware] = field(default_factory=tuple)
+     app_state: Mapping[str, Any] = field(default_factory=dict)
+     require_api_key: bool = True
+     expose_debug_env: bool = True
+     cors_origins: Sequence[str] | None = None
+     startup_hooks: Sequence[Callable[[], None | Awaitable[None]]] = field(default_factory=tuple)
+     shutdown_hooks: Sequence[Callable[[], None | Awaitable[None]]] = field(default_factory=tuple)
+
+     def clone(self) -> "TaskAppConfig":
+         """Return a shallow copy safe to mutate when wiring the app."""
+
+         return TaskAppConfig(
+             app_id=self.app_id,
+             name=self.name,
+             description=self.description,
+             base_task_info=self.base_task_info,
+             describe_taskset=self.describe_taskset,
+             provide_task_instances=self.provide_task_instances,
+             rollout=self.rollout,
+             dataset_registry=self.dataset_registry,
+             rubrics=self.rubrics,
+             proxy=self.proxy,
+             routers=tuple(self.routers),
+             middleware=tuple(self.middleware),
+             app_state=dict(self.app_state),
+             require_api_key=self.require_api_key,
+             expose_debug_env=self.expose_debug_env,
+             cors_origins=tuple(self.cors_origins or ()),
+             startup_hooks=tuple(self.startup_hooks),
+             shutdown_hooks=tuple(self.shutdown_hooks),
+         )
+
+
+ def _maybe_await(result: Any) -> Awaitable[Any]:
+     if inspect.isawaitable(result):
+         return asyncio.ensure_future(result)
+     loop = asyncio.get_event_loop()
+     future: asyncio.Future[Any] = loop.create_future()
+     future.set_result(result)
+     return future
+
+
+ def _ensure_task_info(obj: Any) -> TaskInfo:
+     if isinstance(obj, TaskInfo):
+         return obj
+     if isinstance(obj, MutableMapping):
+         return TaskInfo.model_validate(obj)
+     raise TypeError(f"Task instance provider must yield TaskInfo-compatible objects (got {type(obj)!r})")
+
+
+ def _normalise_seeds(values: Sequence[int]) -> list[int]:
+     seeds: list[int] = []
+     for value in values:
+         try:
+             seeds.append(int(value))
+         except Exception as exc:  # pragma: no cover - defensive
+             raise ValueError(f"Seed values must be convertible to int (got {value!r})") from exc
+     return seeds
+
+
+ def _build_proxy_routes(
+     app: FastAPI, config: TaskAppConfig, auth_dependency: Callable[[Request], None]
+ ) -> None:
+     proxy = config.proxy
+     if not proxy:
+         return
+
+     async def _call_vendor(url: str, payload: dict[str, Any], headers: dict[str, str]) -> dict[str, Any]:
+         async with httpx.AsyncClient(timeout=httpx.Timeout(600.0), follow_redirects=True) as client:
+             response = await client.post(url, json=payload, headers=headers)
+             data = (
+                 response.json()
+                 if response.headers.get("content-type", "").startswith("application/json")
+                 else {"raw": response.text}
+             )
+             if response.status_code >= 400:
+                 code = "vendor_error"
+                 if url.startswith("https://api.openai.com"):
+                     code = "openai_error"
+                 elif "groq" in url:
+                     code = "groq_error"
+                 raise http_exception(
+                     response.status_code,
+                     code,
+                     "Vendor proxy error",
+                     extra={"status": response.status_code, "body": data},
+                 )
+             return data
+
+     def _log_proxy(route: str, payload: dict[str, Any]) -> None:
+         try:
+             messages = payload.get("messages") if isinstance(payload, dict) else None
+             msg_count = len(messages) if isinstance(messages, list) else 0
+             tool_count = len(payload.get("tools") or []) if isinstance(payload, dict) else 0
+             model = payload.get("model") if isinstance(payload, dict) else None
+             print(f"[task:proxy:{route}] model={model} messages={msg_count} tools={tool_count}", flush=True)
+         except Exception:  # pragma: no cover - best effort logging
+             pass
+
+     system_hint = proxy.system_hint
+
+     if proxy.enable_openai:
+         @app.post("/proxy/v1/chat/completions", dependencies=[Depends(auth_dependency)])
+         async def proxy_openai(body: dict[str, Any], request: Request) -> Any:  # type: ignore[no-redef]
+             key = get_openai_key_or_503()
+             model = body.get("model") if isinstance(body.get("model"), str) else None
+             payload = prepare_for_openai(model, body)
+             payload = inject_system_hint(payload, system_hint or "")
+             _log_proxy("openai", payload)
+             data = await _call_vendor(proxy.openai_url, payload, {"Authorization": f"Bearer {key}"})
+             sanitized = synthesize_tool_call_if_missing(data)
+             return to_jsonable(sanitized)
+
+     if proxy.enable_groq:
+         @app.post("/proxy/groq/v1/chat/completions", dependencies=[Depends(auth_dependency)])
+         async def proxy_groq(body: dict[str, Any], request: Request) -> Any:  # type: ignore[no-redef]
+             key = get_groq_key_or_503()
+             model = body.get("model") if isinstance(body.get("model"), str) else None
+             payload = prepare_for_groq(model, body)
+             payload = inject_system_hint(payload, system_hint or "")
+             _log_proxy("groq", payload)
+             data = await _call_vendor(proxy.groq_url.rstrip("/"), payload, {"Authorization": f"Bearer {key}"})
+             sanitized = synthesize_tool_call_if_missing(data)
+             return to_jsonable(sanitized)
+
+
+ def _auth_dependency_factory(config: TaskAppConfig) -> Callable[[Request], None]:
+     def _dependency(request: Request) -> None:
+         if not config.require_api_key:
+             return
+         require_api_key_dependency(request)
+
+     return _dependency
+
+
+ def create_task_app(config: TaskAppConfig) -> FastAPI:
+     cfg = config.clone()
+     app = FastAPI(title=cfg.name, description=cfg.description)
+
+     for key, value in cfg.app_state.items():
+         setattr(app.state, key, value)
+
+     if cfg.cors_origins is not None:
+         app.add_middleware(
+             CORSMiddleware,
+             allow_origins=list(cfg.cors_origins) or ["*"],
+             allow_credentials=True,
+             allow_methods=["*"],
+             allow_headers=["*"],
+         )
+
+     # Note: additional middleware from cfg.middleware is currently disabled to avoid typing ambiguity.
+     # for middleware in cfg.middleware:
+     #     try:
+     #         opts = getattr(middleware, "options", {})
+     #     except Exception:
+     #         opts = {}
+     #     app.add_middleware(middleware.cls, **(opts if isinstance(opts, dict) else {}))
+
+     for router in cfg.routers:
+         try:
+             app.include_router(router)
+         except Exception:
+             try:
+                 inner = getattr(router, "router", None)
+                 if inner is not None:
+                     app.include_router(inner)
+             except Exception:
+                 raise
+
+     auth_dependency = _auth_dependency_factory(cfg)
+
+     def _call_hook(hook: Callable[..., Any]) -> Awaitable[Any]:
+         try:
+             params = inspect.signature(hook).parameters  # type: ignore[arg-type]
+         except (TypeError, ValueError):
+             params = {}
+         if params:
+             return _maybe_await(hook(app))  # type: ignore[misc]
+         return _maybe_await(hook())
+
+     @app.on_event("startup")
+     async def _startup() -> None:  # pragma: no cover - FastAPI lifecycle
+         normalize_environment_api_key()
+         normalize_vendor_keys()
+         for hook in cfg.startup_hooks:
+             await _call_hook(hook)
+
+     @app.on_event("shutdown")
+     async def _shutdown() -> None:  # pragma: no cover - FastAPI lifecycle
+         for hook in cfg.shutdown_hooks:
+             await _call_hook(hook)
+
+     @app.get("/")
+     async def root() -> Mapping[str, Any]:
+         return to_jsonable({"status": "ok", "service": cfg.app_id})
+
+     @app.head("/")
+     async def root_head() -> Mapping[str, Any]:
+         return to_jsonable({"status": "ok"})
+
+     @app.get("/health", dependencies=[Depends(auth_dependency)])
+     async def health(request: Request) -> Mapping[str, Any]:
+         # If we got here, auth_dependency already verified the key exactly matches
+         expected = normalize_environment_api_key()
+         return to_jsonable({"healthy": True, "auth": {"required": True, "expected_prefix": (expected[:6] + '...') if expected else '<unset>'}})
+
+     @app.get("/info", dependencies=[Depends(auth_dependency)])
+     async def info() -> Mapping[str, Any]:
+         dataset_meta = cfg.base_task_info.dataset
+         rubrics: dict[str, Any] | None = None
+         if cfg.rubrics.outcome or cfg.rubrics.events:
+             rubrics = {
+                 "outcome": cfg.rubrics.outcome.model_dump() if cfg.rubrics.outcome else None,
+                 "events": cfg.rubrics.events.model_dump() if cfg.rubrics.events else None,
+             }
+         payload = {
+             "service": {
+                 "task": cfg.base_task_info.task,
+                 "version": cfg.base_task_info.task.get("version"),
+             },
+             "dataset": dataset_meta,
+             "rubrics": rubrics,
+             "inference": cfg.base_task_info.inference,
+             "capabilities": cfg.base_task_info.capabilities,
+             "limits": cfg.base_task_info.limits,
+         }
+         return to_jsonable(payload)
+
+     @app.get("/task_info", dependencies=[Depends(auth_dependency)])
+     async def task_info(
+         seed: Sequence[int] | None = Query(default=None),
+         seeds: Sequence[int] | None = Query(default=None),
+     ) -> Any:
+         all_seeds: list[int] = []
+         if seed:
+             all_seeds.extend(_normalise_seeds(seed))
+         if seeds:
+             all_seeds.extend(_normalise_seeds(seeds))
+
+         if not all_seeds:
+             descriptor_result = await _maybe_await(cfg.describe_taskset())
+             return to_jsonable({"taskset": descriptor_result})
+
+         instances = await _maybe_await(cfg.provide_task_instances(all_seeds))
+         payload = [to_jsonable(_ensure_task_info(instance).model_dump()) for instance in instances]
+         return payload[0] if len(payload) == 1 else payload
+
+     @app.post("/rollout", dependencies=[Depends(auth_dependency)])
+     async def rollout_endpoint(rollout_request: RolloutRequest, request: Request) -> Any:
+         result = await _maybe_await(cfg.rollout(rollout_request, request))
+         if isinstance(result, RolloutResponse):
+             return to_jsonable(result.model_dump())
+         if isinstance(result, Mapping):
+             try:
+                 validated = RolloutResponse.model_validate(result)
+             except Exception:
+                 return to_jsonable(result)
+             return to_jsonable(validated.model_dump())
+         raise TypeError("Rollout executor must return RolloutResponse or mapping")
+
+     if cfg.expose_debug_env:
+         @app.get("/debug/env", dependencies=[Depends(auth_dependency)])
+         async def debug_env() -> Mapping[str, Any]:
+             def _mask(value: str | None) -> str:
+                 if not value:
+                     return ""
+                 return f"{value[:6]}…" if len(value) > 6 else value
+
+             return to_jsonable(
+                 {
+                     "has_ENVIRONMENT_API_KEY": bool(os.getenv("ENVIRONMENT_API_KEY")),
+                     "OPENAI_API_KEY_prefix": _mask(os.getenv("OPENAI_API_KEY")),
+                     "GROQ_API_KEY_prefix": _mask(os.getenv("GROQ_API_KEY")),
+                 }
+             )
+
+     _build_proxy_routes(app, cfg, auth_dependency)
+
+     return app
+
+
+ def _load_env_files(env_files: Sequence[str]) -> list[str]:
+     loaded: list[str] = []
+     if not env_files:
+         return loaded
+     try:
+         import dotenv
+     except Exception:  # pragma: no cover - optional dep
+         return loaded
+     for path_str in env_files:
+         path = Path(path_str)
+         if not path.is_file():
+             continue
+         dotenv.load_dotenv(path, override=False)
+         loaded.append(str(path))
+     return loaded
+
+
+ def run_task_app(
+     config_factory: Callable[[], TaskAppConfig],
+     *,
+     host: str = "0.0.0.0",
+     port: int = 8001,
+     reload: bool = False,
+     env_files: Sequence[str] = (),
+ ) -> None:
+     """Run the provided Task App configuration with uvicorn."""
+
+     loaded_files = _load_env_files(env_files)
+     if loaded_files:
+         print(f"[task:server] Loaded environment from: {', '.join(loaded_files)}", flush=True)
+
+     config = config_factory()
+     app = create_task_app(config)
+
+     try:
+         import uvicorn
+     except ImportError as exc:  # pragma: no cover - uvicorn optional
+         raise RuntimeError("uvicorn must be installed to run the task app locally") from exc
+
+     print(f"[task:server] Starting '{config.app_id}' on {host}:{port}", flush=True)
+     uvicorn.run(app, host=host, port=port, reload=reload)
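
For orientation, here is a minimal sketch of how the new scaffolding appears intended to be wired together, based only on the signatures in the hunk above. All helper names (build_base_task_info, describe_taskset, provide_task_instances, rollout) and literal values are illustrative, not part of the package; TaskInfo construction is elided because its fields live in contracts.py, which this diff does not show in full.

    # Hedged usage sketch, not part of the diff.
    from synth_ai.task.contracts import RolloutRequest, TaskInfo
    from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig, run_task_app

    def build_base_task_info() -> TaskInfo:
        # TaskInfo fields are defined in synth_ai/task/contracts.py (not shown here),
        # so construction is elided in this sketch.
        raise NotImplementedError

    def describe_taskset() -> dict:
        # Sync or async callables are both accepted (see _maybe_await above).
        return {"name": "example-taskset", "size": 100}

    def provide_task_instances(seeds):
        # Must yield TaskInfo-compatible objects, one per requested seed.
        return [build_base_task_info() for _ in seeds]

    async def rollout(rollout_request: RolloutRequest, request) -> dict:
        # May return a RolloutResponse or a plain mapping; mappings are validated
        # against RolloutResponse when possible (see rollout_endpoint above).
        return {"status": "ok"}  # shape is illustrative only

    def build_config() -> TaskAppConfig:
        return TaskAppConfig(
            app_id="example-task",
            name="Example Task App",
            description="Minimal wiring example",
            base_task_info=build_base_task_info(),
            describe_taskset=describe_taskset,
            provide_task_instances=provide_task_instances,
            rollout=rollout,
            proxy=ProxyConfig(enable_openai=True),  # adds /proxy/v1/chat/completions
            rubrics=RubricBundle(),                 # optional outcome/event rubrics
        )

    if __name__ == "__main__":
        # Exposes /, /health, /info, /task_info and /rollout with API-key auth.
        run_task_app(build_config, host="0.0.0.0", port=8001, env_files=(".env",))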
synth_ai/task/tracing_utils.py
@@ -0,0 +1,79 @@
+ from __future__ import annotations
+
+ """Utilities for wiring tracing_v3 into task apps."""
+
+ import os
+ import time
+ from pathlib import Path
+ from typing import Any, Callable
+
+
+ def tracing_env_enabled(default: bool = False) -> bool:
+     """Return True when tracing is enabled for task apps via environment variable."""
+
+     raw = os.getenv("TASKAPP_TRACING_ENABLED")
+     if raw is None:
+         return default
+     raw = raw.strip().lower()
+     if raw in {"1", "true", "t", "yes", "y", "on"}:
+         return True
+     if raw in {"0", "false", "f", "no", "n", "off"}:
+         return False
+     return default
+
+
+ def resolve_tracing_db_url() -> str | None:
+     """Resolve tracing database URL and prefer async drivers for SQLite."""
+
+     db_url = os.getenv("TURSO_LOCAL_DB_URL")
+     if db_url:
+         return db_url
+
+     sqld_path = os.getenv("SQLD_DB_PATH")
+     if sqld_path:
+         path = Path(sqld_path).expanduser()
+         if path.is_dir():
+             candidate = path / "dbs" / "default" / "data"
+             candidate.parent.mkdir(parents=True, exist_ok=True)
+             return f"sqlite+aiosqlite:///{candidate}"
+         else:
+             path.parent.mkdir(parents=True, exist_ok=True)
+             return f"sqlite+aiosqlite:///{path}"
+
+     fallback_path = Path("traces/v3/synth_ai.db").expanduser()
+     fallback_path.parent.mkdir(parents=True, exist_ok=True)
+     return f"sqlite+aiosqlite:///{fallback_path}"
+
+
+ def build_tracer_factory(make_tracer: Callable[..., Any], *, enabled: bool, db_url: str | None) -> Callable[[], Any] | None:
+     """Return a factory that instantiates a tracer when enabled, else None."""
+
+     if not enabled:
+         return None
+
+     def _factory() -> Any:
+         return make_tracer(db_url=db_url) if db_url else make_tracer()
+
+     return _factory
+
+
+ def resolve_sft_output_dir() -> str | None:
+     """Resolve location for writing SFT records, creating directory if requested."""
+
+     raw = os.getenv("TASKAPP_SFT_OUTPUT_DIR") or os.getenv("SFT_OUTPUT_DIR")
+     if not raw:
+         return None
+     path = Path(raw).expanduser()
+     try:
+         path.mkdir(parents=True, exist_ok=True)
+     except Exception:
+         return None
+     return str(path)
+
+
+ def unique_sft_path(base_dir: str, *, run_id: str) -> Path:
+     """Return a unique JSONL path for an SFT record batch."""
+
+     ts = int(time.time() * 1000)
+     name = f"{run_id}_{ts}.jsonl"
+     return Path(base_dir) / name
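
A small usage sketch for the tracing helpers above. It assumes the tracer constructor accepts a db_url keyword (which is what build_tracer_factory implies when a URL is resolved); the actual SessionTracer constructor lives in tracing_v3 and is not fully shown in this diff.

    # Hedged usage sketch, not part of the diff.
    from synth_ai.task.tracing_utils import (
        build_tracer_factory,
        resolve_tracing_db_url,
        tracing_env_enabled,
    )
    from synth_ai.tracing_v3.session_tracer import SessionTracer

    enabled = tracing_env_enabled(default=False)   # reads TASKAPP_TRACING_ENABLED
    db_url = resolve_tracing_db_url()              # TURSO_LOCAL_DB_URL / SQLD_DB_PATH / sqlite fallback
    # Assumption: SessionTracer(db_url=...) is a valid constructor call.
    tracer_factory = build_tracer_factory(SessionTracer, enabled=enabled, db_url=db_url)

    if tracer_factory is not None:
        tracer = tracer_factory()  # e.g. one tracer per rollout/session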
synth_ai/task/vendors.py
@@ -0,0 +1,61 @@
+ from __future__ import annotations
+
+ """Vendor API key helpers shared by Task Apps."""
+
+ import os
+ from typing import Optional
+
+ from .errors import http_exception
+
+ _VENDOR_KEYS = {
+     "OPENAI_API_KEY": ("dev_openai_api_key", "DEV_OPENAI_API_KEY"),
+     "GROQ_API_KEY": ("dev_groq_api_key", "DEV_GROQ_API_KEY"),
+ }
+
+
+ def _mask(value: str, *, prefix: int = 4) -> str:
+     if not value:
+         return "<empty>"
+     visible = value[:prefix]
+     return f"{visible}{'…' if len(value) > prefix else ''}"
+
+
+ def _normalize_single(key: str) -> Optional[str]:
+     direct = os.getenv(key)
+     if direct:
+         return direct
+     fallbacks = _VENDOR_KEYS.get(key, ())
+     for env in fallbacks:
+         candidate = os.getenv(env)
+         if candidate:
+             os.environ[key] = candidate
+             print(
+                 f"[task:vendor] {key} set from {env} (prefix={_mask(candidate)})",
+                 flush=True,
+             )
+             return candidate
+     return None
+
+
+ def normalize_vendor_keys() -> dict[str, Optional[str]]:
+     """Normalise known vendor keys from dev fallbacks and return the mapping."""
+
+     resolved: dict[str, Optional[str]] = {}
+     for key in _VENDOR_KEYS:
+         resolved[key] = _normalize_single(key)
+     return resolved
+
+
+ def get_openai_key_or_503() -> str:
+     key = _normalize_single("OPENAI_API_KEY")
+     if not key:
+         raise http_exception(503, "missing_openai_api_key", "OPENAI_API_KEY is not configured")
+     return key
+
+
+ def get_groq_key_or_503() -> str:
+     key = _normalize_single("GROQ_API_KEY")
+     if not key:
+         raise http_exception(503, "missing_groq_api_key", "GROQ_API_KEY is not configured")
+     return key
+
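
A short sketch of how the vendor-key helpers are meant to be used, both at startup and by the /proxy routes in server.py; the example key value is obviously illustrative.

    # Hedged usage sketch, not part of the diff.
    import os

    from synth_ai.task.vendors import get_openai_key_or_503, normalize_vendor_keys

    # DEV_OPENAI_API_KEY is one of the documented fallback env vars; the value is fake.
    os.environ.setdefault("DEV_OPENAI_API_KEY", "sk-example")

    # Copies dev fallbacks into the canonical env vars when unset and logs a masked prefix.
    resolved = normalize_vendor_keys()
    print(resolved["OPENAI_API_KEY"] is not None)

    # Raises an HTTP 503 via http_exception when the key is still missing,
    # which is how the proxy endpoints surface a misconfigured deployment.
    key = get_openai_key_or_503()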
synth_ai/tracing_v3/session_tracer.py
@@ -234,7 +234,7 @@ class SessionTracer:
          event_id = await self.db.insert_event_row(
              self._current_trace.session_id,
              timestep_db_id=timestep_db_id,
-             event=event,
+             event=event,  # type: ignore[arg-type]
          )
          # Auto-insert an event reward if EnvironmentEvent carries reward
          try:
@@ -323,7 +323,7 @@ class SessionTracer:
              return message_id
          return None

-     async def end_session(self, save: bool = None) -> SessionTrace:
+     async def end_session(self, save: bool | None = None) -> SessionTrace:
          """End the current session.

          Args:
@@ -370,7 +370,7 @@ class SessionTracer:
          self,
          session_id: str | None = None,
          metadata: dict[str, Any] | None = None,
-         save: bool = None,
+         save: bool | None = None,
      ):
          """Context manager for a session.

@@ -414,8 +414,16 @@ class SessionTracer:
          if limit:
              query += f" LIMIT {limit}"

-         df = await self.db.query_traces(query)
-         return df.to_dict("records")
+         # Ensure DB initialized before querying
+         if self.db is None:
+             await self.initialize()
+         df_or_records = await self.db.query_traces(query)  # type: ignore[union-attr]
+         try:
+             # If pandas DataFrame
+             return df_or_records.to_dict("records")  # type: ignore[call-arg, attr-defined]
+         except AttributeError:
+             # Already list of dicts
+             return df_or_records

      async def close(self):
          """Close database connections."""
synth_ai/tracing_v3/storage/base.py
@@ -4,8 +4,6 @@ from abc import ABC, abstractmethod
  from datetime import datetime
  from typing import Any

- import pandas as pd
-
  from ..abstractions import SessionTrace


@@ -42,22 +40,22 @@ class TraceStorage(ABC):
          pass

      @abstractmethod
-     async def query_traces(self, query: str, params: dict[str, Any] = None) -> pd.DataFrame:
-         """Execute a query and return results as DataFrame.
+     async def query_traces(self, query: str, params: dict[str, Any] | None = None) -> Any:
+         """Execute a query and return results.

          Args:
              query: The SQL query to execute
              params: Optional query parameters

          Returns:
-             Query results as a DataFrame
+             Query results as a DataFrame-like object or list of dict records
          """
          pass

      @abstractmethod
      async def get_model_usage(
-         self, start_date: datetime = None, end_date: datetime = None, model_name: str = None
-     ) -> pd.DataFrame:
+         self, start_date: datetime | None = None, end_date: datetime | None = None, model_name: str | None = None
+     ) -> Any:
          """Get model usage statistics.

          Args:
@@ -66,7 +64,7 @@ class TraceStorage(ABC):
              model_name: Optional model name filter

          Returns:
-             Model usage statistics as a DataFrame
+             Model usage statistics as a DataFrame-like object or list of dict records
          """
          pass

@@ -92,8 +90,8 @@ class TraceStorage(ABC):
          self,
          experiment_id: str,
          name: str,
-         description: str = None,
-         configuration: dict[str, Any] = None,
+         description: str | None = None,
+         configuration: dict[str, Any] | None = None,
      ) -> str:
          """Create a new experiment."""
          raise NotImplementedError("Experiment management not supported by this backend")
@@ -103,14 +101,14 @@ class TraceStorage(ABC):
          raise NotImplementedError("Experiment management not supported by this backend")

      async def get_sessions_by_experiment(
-         self, experiment_id: str, limit: int = None
+         self, experiment_id: str, limit: int | None = None
      ) -> list[dict[str, Any]]:
          """Get all sessions for an experiment."""
          raise NotImplementedError("Experiment management not supported by this backend")

      # Batch operations
      async def batch_insert_sessions(
-         self, traces: list[SessionTrace], batch_size: int = 1000
+         self, traces: list[SessionTrace], batch_size: int | None = 1000
      ) -> list[str]:
          """Batch insert multiple session traces.