synth-ai 0.2.8.dev12__py3-none-any.whl → 0.2.9.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/api/train/__init__.py +5 -0
- synth_ai/api/train/builders.py +165 -0
- synth_ai/api/train/cli.py +450 -0
- synth_ai/api/train/config_finder.py +168 -0
- synth_ai/api/train/env_resolver.py +302 -0
- synth_ai/api/train/pollers.py +66 -0
- synth_ai/api/train/task_app.py +193 -0
- synth_ai/api/train/utils.py +232 -0
- synth_ai/cli/__init__.py +23 -0
- synth_ai/cli/rl_demo.py +18 -6
- synth_ai/cli/root.py +38 -6
- synth_ai/cli/task_apps.py +1107 -0
- synth_ai/demo_registry.py +258 -0
- synth_ai/demos/core/cli.py +147 -111
- synth_ai/demos/demo_task_apps/__init__.py +7 -1
- synth_ai/demos/demo_task_apps/math/config.toml +55 -110
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +157 -21
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +39 -0
- synth_ai/task/__init__.py +94 -1
- synth_ai/task/apps/__init__.py +88 -0
- synth_ai/task/apps/grpo_crafter.py +438 -0
- synth_ai/task/apps/math_single_step.py +852 -0
- synth_ai/task/auth.py +153 -0
- synth_ai/task/client.py +165 -0
- synth_ai/task/contracts.py +29 -14
- synth_ai/task/datasets.py +105 -0
- synth_ai/task/errors.py +49 -0
- synth_ai/task/json.py +77 -0
- synth_ai/task/proxy.py +258 -0
- synth_ai/task/rubrics.py +212 -0
- synth_ai/task/server.py +398 -0
- synth_ai/task/tracing_utils.py +79 -0
- synth_ai/task/vendors.py +61 -0
- synth_ai/tracing_v3/session_tracer.py +13 -5
- synth_ai/tracing_v3/storage/base.py +10 -12
- synth_ai/tracing_v3/turso/manager.py +20 -6
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/METADATA +3 -2
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/RECORD +42 -18
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/top_level.txt +0 -0
synth_ai/task/auth.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Authentication helpers shared by Task Apps."""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Iterable, Optional, Any, Set
|
|
7
|
+
|
|
8
|
+
from .errors import http_exception
|
|
9
|
+
|
|
10
|
+
# Environment variable that holds the primary Task App API key.
_API_KEY_ENV = "ENVIRONMENT_API_KEY"
# Development fallbacks, consulted in this order when the primary variable is unset.
_DEV_API_KEY_ENVS = ("dev_environment_api_key", "DEV_ENVIRONMENT_API_KEY")
# Request headers that may carry credentials.
_API_KEY_HEADER = "x-api-key"
_API_KEYS_HEADER = "x-api-keys"
_AUTH_HEADER = "authorization"
# Environment variable with a comma-separated list of additional valid keys.
_API_KEY_ALIASES_ENV = "ENVIRONMENT_API_KEY_ALIASES"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _mask(value: str, *, prefix: int = 4) -> str:
    """Return a log-safe preview of ``value``.

    At most the first ``prefix`` characters are shown, followed by an
    ellipsis when the value was truncated. A falsy value is rendered as the
    literal ``"<empty>"``.
    """
    if value:
        ellipsis = "…" if len(value) > prefix else ""
        return f"{value[:prefix]}{ellipsis}"
    return "<empty>"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def normalize_environment_api_key() -> Optional[str]:
    """Resolve the Task App API key, filling it in from dev fallbacks if needed.

    When ``ENVIRONMENT_API_KEY`` is already set (and non-empty) it is returned
    untouched. Otherwise the dev fallback variables are checked in order; the
    first non-empty one is copied into ``ENVIRONMENT_API_KEY``, a masked notice
    is printed, and its value is returned. Returns ``None`` when nothing is
    configured.
    """
    configured = os.getenv(_API_KEY_ENV)
    if configured:
        return configured

    for fallback_name in _DEV_API_KEY_ENVS:
        fallback_value = os.getenv(fallback_name)
        if not fallback_value:
            continue
        # Promote the dev key so later lookups see the canonical variable.
        os.environ[_API_KEY_ENV] = fallback_value
        print(
            f"[task:auth] {_API_KEY_ENV} set from {fallback_name} (prefix={_mask(fallback_value)})",
            flush=True,
        )
        return fallback_value
    return None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def allowed_environment_api_keys() -> Set[str]:
    """Collect every API key this Task App accepts.

    The result contains the primary ``ENVIRONMENT_API_KEY`` (after dev-fallback
    normalisation) plus each non-empty, whitespace-trimmed entry of the
    comma-separated ``ENVIRONMENT_API_KEY_ALIASES`` variable. The set is empty
    when nothing is configured.
    """
    primary = normalize_environment_api_key()
    raw_aliases = (os.getenv(_API_KEY_ALIASES_ENV) or "").strip()

    accepted: set[str] = {primary} if primary else set()
    # Splitting an empty string yields [""], which the emptiness filter drops.
    for piece in raw_aliases.split(","):
        alias = piece.strip()
        if alias:
            accepted.add(alias)
    return accepted
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _header_values(request: Any, header: str) -> Iterable[str]:
    """Look up ``header`` on a request-like object.

    Sources are tried in order: a truthy ``request.headers`` mapping, the
    request itself when it is a plain ``dict``, then ``request.headers`` and
    ``request.state`` when they are plain dicts. Each source is queried with
    the header name as given and then lower-cased. The first value that is not
    ``None`` is returned as a one-element list; otherwise the list is empty.
    """
    lowered = header.lower()
    if request is None:
        return []

    def _lookup(source: Any) -> Any:
        # Exact spelling first, lower-case spelling as a fallback.
        return source.get(header) or source.get(lowered)

    header_map = getattr(request, "headers", None)
    if header_map:
        found = _lookup(header_map)
        if found is not None:
            return [found]

    if isinstance(request, dict):
        found = _lookup(request)
        if found is not None:
            return [found]

    # Support passing explicit header dict via keyword arg on FastAPI route handlers
    for attr_name in ("headers", "state"):
        holder = getattr(request, attr_name, None)
        if not isinstance(holder, dict):
            continue
        found = _lookup(holder)
        if found is not None:
            return [found]
    return []
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _split_csv(values: Iterable[str]) -> list[str]:
    """Flatten comma-separated strings into one ordered list of tokens.

    Non-string items are ignored, each token is whitespace-trimmed, and empty
    tokens are dropped. Order is preserved and duplicates are kept.
    """
    trimmed_tokens = (
        piece.strip()
        for item in values
        if isinstance(item, str)
        for piece in item.split(",")
    )
    return [token for token in trimmed_tokens if token]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def is_api_key_header_authorized(request: Any) -> bool:
    """Report whether the request presents at least one accepted API key.

    Keys are read from ``x-api-key``, ``x-api-keys`` (comma-separated) and a
    ``Bearer`` token in ``authorization``. Returns ``False`` when no key is
    configured on the server side or when none of the presented keys match.
    """
    accepted = allowed_environment_api_keys()
    if not accepted:
        return False

    presented = list(_header_values(request, _API_KEY_HEADER))
    presented += list(_header_values(request, _API_KEYS_HEADER))
    # Only "Bearer <token>" authorization values contribute a candidate key.
    presented += [
        auth.split(" ", 1)[1].strip()
        for auth in list(_header_values(request, _AUTH_HEADER))
        if isinstance(auth, str) and auth.lower().startswith("bearer ")
    ]

    for candidate in _split_csv(presented):
        if candidate in accepted:
            return True
    return False
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def require_api_key_dependency(request: Any) -> None:
    """FastAPI dependency enforcing Task App authentication headers.

    Candidate keys are read from ``x-api-key``, ``x-api-keys``
    (comma-separated) and a ``Bearer`` token in ``authorization``; the request
    passes when any candidate is one of the allowed environment keys.

    Raises:
        HTTPException 503: no environment API key is configured.
        HTTPException 400: no presented key matches an allowed key.
    """

    allowed = allowed_environment_api_keys()
    if not allowed:
        raise http_exception(503, "missing_environment_api_key", "ENVIRONMENT_API_KEY is not configured")

    # Build candidate list for verbose diagnostics
    single = list(_header_values(request, _API_KEY_HEADER))
    multi = list(_header_values(request, _API_KEYS_HEADER))
    auths = list(_header_values(request, _AUTH_HEADER))
    bearer = [
        a.split(" ", 1)[1].strip()
        for a in auths
        if isinstance(a, str) and a.lower().startswith("bearer ")
    ]
    candidates = _split_csv(single + multi + bearer)
    if any(candidate in allowed for candidate in candidates):
        return

    # Information about what the caller sent; safe to echo back to that caller.
    got_first15 = [c[:15] for c in candidates]
    got_lens = [len(c) for c in candidates]

    # Server-side diagnostics only. Logging is best-effort and must never mask
    # the authentication failure raised below.
    try:
        print({
            "task_auth_failed": True,
            "allowed_first15": [k[:15] for k in allowed],
            "allowed_count": len(allowed),
            "got_first15": got_first15,
            "got_lens": got_lens,
            "have_x_api_key": bool(single),
            "have_x_api_keys": bool(multi),
            "have_authorization": bool(auths),
        }, flush=True)
    except Exception:
        pass

    # Use 400 to make failures unmistakable during preflight.
    # SECURITY: the response goes to an unauthenticated caller, so it must not
    # contain any part of the configured secrets; prefixes of the allowed keys
    # stay in the server log only.
    raise http_exception(400, "unauthorised", "API key missing or invalid", extra={
        "allowed_count": len(allowed),
        "got_first15": got_first15,
        "got_lens": got_lens,
    })
|
synth_ai/task/client.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Async HTTP client for interacting with Task Apps."""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from typing import Any, Dict, Iterable, List, Optional
|
|
7
|
+
import os
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
from pydantic import BaseModel
|
|
11
|
+
|
|
12
|
+
from .contracts import RolloutRequest, RolloutResponse, TaskInfo
|
|
13
|
+
from .json import to_jsonable
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _prepare_payload(payload: Any) -> Any:
    """Turn a request payload into JSON-compatible data.

    Pydantic models are dumped in JSON mode using field aliases, ``None`` is
    passed through unchanged, and everything else goes through
    ``to_jsonable``.
    """
    if isinstance(payload, BaseModel):
        return payload.model_dump(mode="json", by_alias=True)
    return None if payload is None else to_jsonable(payload)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TaskAppClient:
    """Async HTTP client for a Task App service.

    Wraps a lazily created ``httpx.AsyncClient``, attaches API-key headers to
    every request and retries failed requests a bounded number of times. It
    can be used as an async context manager; environment lifecycle endpoints
    are grouped under ``self.env``.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str | None = None,
        *,
        timeout: float = 600.0,
        retries: int = 3,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.timeout = timeout
        # Always allow at least one attempt, even when retries <= 0.
        self.retries = max(1, retries)
        self._client: httpx.AsyncClient | None = None
        self.env = _TaskAppEnvironmentClient(self)

    async def __aenter__(self) -> "TaskAppClient":
        await self._ensure_client()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.aclose()

    async def _ensure_client(self) -> httpx.AsyncClient:
        """Return the shared ``httpx.AsyncClient``, creating it on first use."""
        existing = self._client
        if existing is not None:
            return existing
        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=httpx.Timeout(self.timeout),
            follow_redirects=True,
        )
        return self._client

    def _headers(self) -> Dict[str, str]:
        """Build the authentication headers sent with every request.

        The primary key (if any) is sent as ``X-API-Key`` and as a bearer
        token. ``X-API-Keys`` carries the primary key followed by every alias
        from ``ENVIRONMENT_API_KEY_ALIASES``, de-duplicated, as CSV.
        """
        primary = (self.api_key or "").strip()
        alias_csv = (os.getenv("ENVIRONMENT_API_KEY_ALIASES") or "").strip()

        result: Dict[str, str] = {}
        key_chain: list[str] = []
        if primary:
            result["X-API-Key"] = primary
            # Also set Authorization for clients that read bearer tokens
            result["Authorization"] = f"Bearer {primary}"
            key_chain.append(primary)

        # Include ALL available environment keys via CSV in X-API-Keys
        for chunk in alias_csv.split(","):
            alias = chunk.strip()
            if alias and alias not in key_chain:
                key_chain.append(alias)
        if key_chain:
            result["X-API-Keys"] = ",".join(key_chain)
        return result

    async def aclose(self) -> None:
        """Close the underlying HTTP client, if one was created."""
        if self._client is None:
            return
        await self._client.aclose()
        self._client = None

    async def _request(
        self,
        method: str,
        path: str,
        *,
        params: Optional[Dict[str, Any] | List[tuple[str, Any]]] = None,
        json_payload: Any = None,
    ) -> httpx.Response:
        """Send one request, retrying on 5xx responses and transport errors.

        Up to ``self.retries`` attempts are made with a linearly growing pause
        (0.1s, 0.2s, ...) between them. Non-5xx error statuses and the failure
        of the final attempt are re-raised to the caller.
        """
        http = await self._ensure_client()
        body = _prepare_payload(json_payload)
        auth_headers = self._headers()
        failure: Exception | None = None

        for attempt_no in range(1, self.retries + 1):
            out_of_attempts = attempt_no >= self.retries
            try:
                reply = await http.request(
                    method,
                    path,
                    headers=auth_headers,
                    params=params,
                    json=body,
                )
                reply.raise_for_status()
                return reply
            except httpx.HTTPStatusError as err:
                server_side = 500 <= err.response.status_code < 600
                if out_of_attempts or not server_side:
                    raise
                failure = err
            except httpx.HTTPError as err:
                failure = err
                if out_of_attempts:
                    raise
            # Reached only after a retryable failure: back off, then try again.
            await asyncio.sleep(0.1 * attempt_no)

        if failure:  # pragma: no cover - defensive
            raise failure
        raise RuntimeError("Unreachable code in TaskAppClient._request")

    async def health(self) -> Dict[str, Any]:
        """GET ``/health`` and return the decoded JSON body."""
        return (await self._request("GET", "/health")).json()

    async def info(self) -> Dict[str, Any]:
        """GET ``/info`` and return the decoded JSON body."""
        return (await self._request("GET", "/info")).json()

    async def task_info(self, seeds: list[int] | None = None) -> TaskInfo | list[TaskInfo]:
        """GET ``/task_info``, optionally for specific seeds.

        Each seed is sent as a repeated ``seed`` query parameter. A JSON list
        response is validated item by item; any other response is validated
        as a single ``TaskInfo``.
        """
        query: Optional[List[tuple[str, Any]]] = (
            [("seed", seed) for seed in seeds] if seeds else None
        )
        reply = await self._request("GET", "/task_info", params=query)
        decoded = reply.json()
        if not isinstance(decoded, list):
            return TaskInfo.model_validate(decoded)
        return [TaskInfo.model_validate(entry) for entry in decoded]

    async def rollout(self, request: RolloutRequest) -> RolloutResponse:
        """POST ``/rollout`` and validate the reply as a ``RolloutResponse``."""
        reply = await self._request("POST", "/rollout", json_payload=request)
        return RolloutResponse.model_validate(reply.json())
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class _TaskAppEnvironmentClient:
    """Environment lifecycle calls (initialize / step / terminate).

    Every method POSTs to ``/env/{env_name}/...`` through the owning client's
    ``_request`` and returns the decoded JSON body.
    """

    def __init__(self, client: TaskAppClient) -> None:
        self._client = client

    async def initialize(self, env_name: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` to the environment's ``initialize`` endpoint."""
        endpoint = f"/env/{env_name}/initialize"
        reply = await self._client._request("POST", endpoint, json_payload=payload)
        return reply.json()

    async def step(self, env_name: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` to the environment's ``step`` endpoint."""
        endpoint = f"/env/{env_name}/step"
        reply = await self._client._request("POST", endpoint, json_payload=payload)
        return reply.json()

    async def terminate(self, env_name: str, payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        """POST to the ``terminate`` endpoint; a missing or empty payload is sent as ``{}``."""
        endpoint = f"/env/{env_name}/terminate"
        body = payload if payload else {}
        reply = await self._client._request("POST", endpoint, json_payload=body)
        return reply.json()
|
synth_ai/task/contracts.py
CHANGED
|
@@ -1,31 +1,27 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Optional, Any, Dict, List
|
|
5
|
-
from pydantic import BaseModel
|
|
4
|
+
from typing import Optional, Any, Dict, List, Literal
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
@dataclass(frozen=True)
|
|
9
9
|
class TaskAppEndpoints:
|
|
10
10
|
"""Canonical Task App endpoint shapes used by RL trainers.
|
|
11
11
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
- Environment lifecycle:
|
|
17
|
-
• POST /env/{env_name}/initialize → { env_id, observation }
|
|
18
|
-
• POST /env/{env_name}/step → { observation, reward, done, info }
|
|
19
|
-
• POST /env/{env_name}/terminate → { ok: true }
|
|
20
|
-
- Rollout (optional, unified schema):
|
|
21
|
-
• POST /rollout → { run_id, trajectories[], metrics, ... }
|
|
22
|
-
- Proxy (optional):
|
|
23
|
-
• POST /proxy/v1/chat/completions (for direct OpenAI calls from Task App)
|
|
12
|
+
Task Apps run as lightweight HTTP services (often on Modal) that expose a
|
|
13
|
+
consistent set of endpoints for health, metadata, environment lifecycle,
|
|
14
|
+
rollouts, and optional proxy access to vendor models. The endpoint strings
|
|
15
|
+
defined here act as defaults and documentation for clients.
|
|
24
16
|
"""
|
|
25
17
|
|
|
18
|
+
root: str = "/"
|
|
26
19
|
health: str = "/health"
|
|
20
|
+
info: str = "/info"
|
|
21
|
+
task_info: str = "/task_info"
|
|
27
22
|
rollout: str = "/rollout"
|
|
28
23
|
proxy_chat_completions: str = "/proxy/v1/chat/completions"
|
|
24
|
+
proxy_groq_chat_completions: str = "/proxy/groq/v1/chat/completions"
|
|
29
25
|
env_initialize: str = "/env/{env_name}/initialize"
|
|
30
26
|
env_step: str = "/env/{env_name}/step"
|
|
31
27
|
env_terminate: str = "/env/{env_name}/terminate"
|
|
@@ -67,6 +63,8 @@ class RolloutRecordConfig(BaseModel):
|
|
|
67
63
|
trajectories: bool = True
|
|
68
64
|
logprobs: bool = False
|
|
69
65
|
value: bool = False
|
|
66
|
+
return_trace: bool = False
|
|
67
|
+
trace_format: Literal["compact", "full"] = "compact"
|
|
70
68
|
|
|
71
69
|
|
|
72
70
|
class RolloutSafetyConfig(BaseModel):
|
|
@@ -108,6 +106,9 @@ class RolloutMetrics(BaseModel):
|
|
|
108
106
|
mean_return: float
|
|
109
107
|
num_steps: int
|
|
110
108
|
num_episodes: int = 0
|
|
109
|
+
outcome_score: Optional[float] = None
|
|
110
|
+
events_score: Optional[float] = None
|
|
111
|
+
details: Dict[str, Any] = Field(default_factory=dict)
|
|
111
112
|
|
|
112
113
|
|
|
113
114
|
class RolloutResponse(BaseModel):
|
|
@@ -117,4 +118,18 @@ class RolloutResponse(BaseModel):
|
|
|
117
118
|
metrics: RolloutMetrics
|
|
118
119
|
aborted: bool = False
|
|
119
120
|
ops_executed: int = 0
|
|
121
|
+
trace: Dict[str, Any] | None = None
|
|
120
122
|
|
|
123
|
+
|
|
124
|
+
class TaskInfo(BaseModel):
    """Static metadata describing the capabilities of a Task App task.

    All fields are required. Apart from ``environments`` they are free-form
    mappings; this model does not constrain their inner structure.
    """

    # Task descriptor mapping.
    task: Dict[str, Any]
    # Environment names, as plain strings.
    environments: List[str]
    # Free-form descriptions of the action space and the observation.
    action_space: Dict[str, Any]
    observation: Dict[str, Any]
    # Free-form dataset, rubric and inference metadata.
    dataset: Dict[str, Any]
    rubric: Dict[str, Any]
    inference: Dict[str, Any]
    # Free-form capability and limit mappings.
    capabilities: Dict[str, Any]
    limits: Dict[str, Any]
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Dataset registry and helpers shared by Task Apps."""
|
|
4
|
+
|
|
5
|
+
from typing import Any, Callable, Dict, Hashable, Tuple
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field, field_validator
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TaskDatasetSpec(BaseModel):
    """Declarative metadata describing a dataset that a Task App exposes.

    Only ``id`` and ``name`` are required. When ``default_split`` is given it
    must be one of ``splits``.
    """

    id: str
    name: str
    version: str | None = None
    splits: list[str] = Field(default_factory=list)
    default_split: str | None = None
    cardinality: int | None = None
    description: str | None = None

    @field_validator("default_split")
    @classmethod
    def _validate_default_split(cls, value: str | None, info):
        """Reject a non-empty ``default_split`` that is not listed in ``splits``."""
        # `info.data` is read for the already-declared `splits` field.
        validated = getattr(info, "data", {})  # type: ignore[attr-defined]
        known_splits = validated.get("splits") or []
        if not value or value in known_splits:
            return value
        raise ValueError("default_split must be one of splits when provided")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
RegistryLoader = Callable[[TaskDatasetSpec], Any]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TaskDatasetRegistry:
    """Lightweight registry mapping dataset specs to loader callables.

    Each entry stores the spec, its loader and a flag saying whether loaded
    datasets are cached. Cached results are keyed by
    ``(id, version, default_split)`` of the effective spec.
    """

    def __init__(self) -> None:
        self._entries: Dict[str, Tuple[TaskDatasetSpec, RegistryLoader, bool]] = {}
        self._cache: Dict[Hashable, Any] = {}

    def register(self, spec: TaskDatasetSpec, loader: RegistryLoader, *, cache: bool = True) -> None:
        """Register (or replace) the loader and metadata for ``spec.id``."""
        self._entries[spec.id] = (spec, loader, cache)

    def describe(self, dataset_id: str) -> TaskDatasetSpec:
        """Return the registered spec for ``dataset_id``; raise ``KeyError`` if unknown."""
        entry = self._entries.get(dataset_id)
        if entry is None:
            raise KeyError(f"Dataset not registered: {dataset_id}")
        return entry[0]

    def list(self) -> list[TaskDatasetSpec]:
        """Return every registered spec in registration order."""
        return [registered for registered, _loader, _cached in self._entries.values()]

    def get(self, spec: TaskDatasetSpec | str) -> Any:
        """Return dataset materialisation (with optional caching).

        ``spec`` may be a dataset id or a spec object. For a spec object, the
        fields explicitly set on it override the registered spec before the
        loader is called. Raises ``KeyError`` for an unregistered dataset.
        """
        by_id = isinstance(spec, str)
        dataset_id = spec if by_id else spec.id
        if dataset_id not in self._entries:
            raise KeyError(f"Dataset not registered: {dataset_id}")

        registered, loader, use_cache = self._entries[dataset_id]
        if by_id:
            resolved = registered
        else:
            overrides = spec.model_dump(exclude_unset=True)
            resolved = registered.model_copy(update=overrides)

        if not use_cache:
            return loader(resolved)

        key: Hashable = (resolved.id, resolved.version, resolved.default_split)
        if key not in self._cache:
            self._cache[key] = loader(resolved)
        return self._cache[key]

    @staticmethod
    def ensure_split(spec: TaskDatasetSpec, split: str | None) -> str:
        """Validate that `split` exists on the spec; return a concrete split.

        A spec without declared splits accepts anything and falls back to its
        default split, then to ``"default"``. Otherwise a missing ``split``
        resolves to the default split (``ValueError`` if there is none) and an
        unknown ``split`` raises ``ValueError``.
        """
        if spec.splits:
            if split is None:
                if spec.default_split:
                    return spec.default_split
                raise ValueError(f"split must be provided for dataset {spec.id}")
            if split not in spec.splits:
                raise ValueError(f"Unknown split '{split}' for dataset {spec.id}")
            return split
        return split or spec.default_split or "default"

    @staticmethod
    def normalise_seed(seed: Any, *, cardinality: int | None = None) -> int:
        """Normalise arbitrary seed input into a bounded non-negative integer.

        The seed is converted with ``int()``, made non-negative, and reduced
        modulo ``cardinality`` when that is a positive number. Raises
        ``ValueError`` when the seed cannot be converted.
        """
        try:
            number = int(seed)
        except Exception as exc:  # pragma: no cover - defensive
            raise ValueError(f"Seed must be convertible to int (got {seed!r})") from exc
        number = abs(number)
        if cardinality and cardinality > 0:
            number %= cardinality
        return number
|
synth_ai/task/errors.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Error helpers used across Task App implementations."""
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, Optional
|
|
6
|
+
|
|
7
|
+
from .json import to_jsonable
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def error_payload(code: str, message: str, *, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Build the standard error envelope ``{"error": {...}}``.

    The inner mapping holds ``code`` and ``message``. Entries of ``extra`` are
    merged into it afterwards, so they may add to or override those keys.
    """
    detail: Dict[str, Any] = {"code": code, "message": message}
    if extra:
        detail.update(extra)
    return {"error": detail}
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def http_exception(
    status_code: int,
    code: str,
    message: str,
    *,
    extra: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, str]] = None,
):
    """Build (not raise) a FastAPI ``HTTPException`` with the standard error body.

    FastAPI is imported lazily. If that import fails, ``RuntimeError`` is
    raised instead. The exception's ``detail`` is the JSON-safe error
    envelope produced from ``code``, ``message`` and ``extra``.
    """
    try:
        from fastapi import HTTPException  # type: ignore
    except Exception as exc:  # pragma: no cover - FastAPI not installed
        raise RuntimeError("fastapi must be installed to raise HTTPException") from exc

    detail = to_jsonable(error_payload(code, message, extra=extra))
    return HTTPException(status_code=status_code, detail=detail, headers=headers)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def json_error_response(
    status_code: int,
    code: str,
    message: str,
    *,
    extra: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, str]] = None,
):
    """Build a FastAPI ``JSONResponse`` carrying the standard error body.

    FastAPI is imported lazily. If that import fails, ``RuntimeError`` is
    raised instead. The response content is the JSON-safe error envelope
    produced from ``code``, ``message`` and ``extra``.
    """
    try:
        from fastapi.responses import JSONResponse  # type: ignore
    except Exception as exc:  # pragma: no cover - FastAPI not installed
        raise RuntimeError("fastapi must be installed to build JSONResponse") from exc

    content = to_jsonable(error_payload(code, message, extra=extra))
    return JSONResponse(status_code=status_code, content=content, headers=headers)
|
|
49
|
+
|
synth_ai/task/json.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Shared JSON sanitisation helpers for Task Apps."""
|
|
4
|
+
|
|
5
|
+
from collections.abc import Mapping, Sequence
|
|
6
|
+
from dataclasses import is_dataclass, asdict
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
try: # numpy is optional at runtime; degrade gracefully if absent
|
|
11
|
+
import numpy as _np # type: ignore
|
|
12
|
+
except Exception: # pragma: no cover - handled at runtime
|
|
13
|
+
_np = None # type: ignore
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _mask_numpy_array(arr: "_np.ndarray") -> str:
    """Describe an array by shape and dtype instead of serialising its data.

    Missing ``shape`` / ``dtype`` attributes are reported as ``None``.
    """
    described = {name: getattr(arr, name, None) for name in ("shape", "dtype")}
    return "<ndarray shape={shape} dtype={dtype}>".format(**described)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def to_jsonable(value: Any) -> Any:
|
|
23
|
+
"""Convert `value` into structures compatible with JSON serialisation.
|
|
24
|
+
|
|
25
|
+
- numpy scalars are converted to their Python counterparts
|
|
26
|
+
- numpy arrays are represented by a compact descriptor string
|
|
27
|
+
- dataclasses, Enums, and pydantic models are unwrapped recursively
|
|
28
|
+
- sets and tuples are converted to lists
|
|
29
|
+
- non-serialisable objects fall back to `repr`
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
if value is None or isinstance(value, (str, bool, int, float)):
|
|
33
|
+
return value
|
|
34
|
+
|
|
35
|
+
# numpy scalars / arrays
|
|
36
|
+
if _np is not None:
|
|
37
|
+
if isinstance(value, (_np.integer,)):
|
|
38
|
+
return int(value)
|
|
39
|
+
if isinstance(value, (_np.floating,)):
|
|
40
|
+
return float(value)
|
|
41
|
+
if isinstance(value, (_np.bool_,)):
|
|
42
|
+
return bool(value)
|
|
43
|
+
if isinstance(value, (_np.ndarray,)):
|
|
44
|
+
return _mask_numpy_array(value)
|
|
45
|
+
|
|
46
|
+
if isinstance(value, Enum):
|
|
47
|
+
return to_jsonable(value.value)
|
|
48
|
+
|
|
49
|
+
if is_dataclass(value):
|
|
50
|
+
return to_jsonable(asdict(value))
|
|
51
|
+
|
|
52
|
+
# pydantic BaseModel / attrs objects
|
|
53
|
+
for attr in ("model_dump", "dict", "to_dict", "to_json"):
|
|
54
|
+
if hasattr(value, attr) and callable(getattr(value, attr, None)):
|
|
55
|
+
try:
|
|
56
|
+
dumped = getattr(value, attr)() # type: ignore[misc]
|
|
57
|
+
except TypeError:
|
|
58
|
+
dumped = getattr(value, attr)(exclude_none=False) # pragma: no cover
|
|
59
|
+
return to_jsonable(dumped)
|
|
60
|
+
|
|
61
|
+
if isinstance(value, Mapping):
|
|
62
|
+
return {str(k): to_jsonable(v) for k, v in value.items()}
|
|
63
|
+
|
|
64
|
+
if isinstance(value, (set, tuple)):
|
|
65
|
+
return [to_jsonable(v) for v in value]
|
|
66
|
+
|
|
67
|
+
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
|
68
|
+
return [to_jsonable(v) for v in value]
|
|
69
|
+
|
|
70
|
+
if isinstance(value, (bytes, bytearray)):
|
|
71
|
+
return f"<bytes len={len(value)}>"
|
|
72
|
+
|
|
73
|
+
if hasattr(value, "__dict__"):
|
|
74
|
+
return to_jsonable(vars(value))
|
|
75
|
+
|
|
76
|
+
return repr(value)
|
|
77
|
+
|