synth-ai 0.2.8.dev12__py3-none-any.whl → 0.2.9.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. synth_ai/api/train/__init__.py +5 -0
  2. synth_ai/api/train/builders.py +165 -0
  3. synth_ai/api/train/cli.py +450 -0
  4. synth_ai/api/train/config_finder.py +168 -0
  5. synth_ai/api/train/env_resolver.py +302 -0
  6. synth_ai/api/train/pollers.py +66 -0
  7. synth_ai/api/train/task_app.py +193 -0
  8. synth_ai/api/train/utils.py +232 -0
  9. synth_ai/cli/__init__.py +23 -0
  10. synth_ai/cli/rl_demo.py +18 -6
  11. synth_ai/cli/root.py +38 -6
  12. synth_ai/cli/task_apps.py +1107 -0
  13. synth_ai/demo_registry.py +258 -0
  14. synth_ai/demos/core/cli.py +147 -111
  15. synth_ai/demos/demo_task_apps/__init__.py +7 -1
  16. synth_ai/demos/demo_task_apps/math/config.toml +55 -110
  17. synth_ai/demos/demo_task_apps/math/modal_task_app.py +157 -21
  18. synth_ai/demos/demo_task_apps/math/task_app_entry.py +39 -0
  19. synth_ai/task/__init__.py +94 -1
  20. synth_ai/task/apps/__init__.py +88 -0
  21. synth_ai/task/apps/grpo_crafter.py +438 -0
  22. synth_ai/task/apps/math_single_step.py +852 -0
  23. synth_ai/task/auth.py +153 -0
  24. synth_ai/task/client.py +165 -0
  25. synth_ai/task/contracts.py +29 -14
  26. synth_ai/task/datasets.py +105 -0
  27. synth_ai/task/errors.py +49 -0
  28. synth_ai/task/json.py +77 -0
  29. synth_ai/task/proxy.py +258 -0
  30. synth_ai/task/rubrics.py +212 -0
  31. synth_ai/task/server.py +398 -0
  32. synth_ai/task/tracing_utils.py +79 -0
  33. synth_ai/task/vendors.py +61 -0
  34. synth_ai/tracing_v3/session_tracer.py +13 -5
  35. synth_ai/tracing_v3/storage/base.py +10 -12
  36. synth_ai/tracing_v3/turso/manager.py +20 -6
  37. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/METADATA +3 -2
  38. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/RECORD +42 -18
  39. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/WHEEL +0 -0
  40. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/entry_points.txt +0 -0
  41. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/licenses/LICENSE +0 -0
  42. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/top_level.txt +0 -0
synth_ai/task/auth.py ADDED
@@ -0,0 +1,153 @@
1
+ from __future__ import annotations
2
+
3
+ """Authentication helpers shared by Task Apps."""
4
+
5
+ import os
6
+ from typing import Iterable, Optional, Any, Set
7
+
8
+ from .errors import http_exception
9
+
10
# Environment variable holding the primary Task App API key.
_API_KEY_ENV = "ENVIRONMENT_API_KEY"
# Fallback variables consulted, in order, when _API_KEY_ENV is unset or empty.
_DEV_API_KEY_ENVS = ("dev_environment_api_key", "DEV_ENVIRONMENT_API_KEY")
# Environment variable with a comma-separated list of additional valid keys.
_API_KEY_ALIASES_ENV = "ENVIRONMENT_API_KEY_ALIASES"

# Request headers inspected for credentials.
_API_KEY_HEADER = "x-api-key"  # a single key
_API_KEYS_HEADER = "x-api-keys"  # comma-separated keys
_AUTH_HEADER = "authorization"  # "Bearer <key>"
16
+
17
+
18
def _mask(value: str, *, prefix: int = 4) -> str:
    """Return a log-safe preview of ``value``.

    Only the first ``prefix`` characters are kept, followed by an ellipsis
    when anything was cut off. Falsy input yields ``"<empty>"``.
    """
    if not value:
        return "<empty>"
    ellipsis = "…" if len(value) > prefix else ""
    return value[:prefix] + ellipsis
23
+
24
+
25
def normalize_environment_api_key() -> Optional[str]:
    """Resolve the primary Task App key, promoting a dev fallback if needed.

    If ``ENVIRONMENT_API_KEY`` is set to a non-empty value it is returned
    unchanged. Otherwise the first non-empty variable named in
    ``_DEV_API_KEY_ENVS`` is copied into ``ENVIRONMENT_API_KEY`` (mutating
    ``os.environ``). A notice with a masked prefix is printed and that value
    is returned. Returns ``None`` when no candidate is set.
    """
    primary = os.getenv(_API_KEY_ENV)
    if primary:
        return primary
    for fallback_name in _DEV_API_KEY_ENVS:
        fallback_value = os.getenv(fallback_name)
        if not fallback_value:
            continue
        os.environ[_API_KEY_ENV] = fallback_value
        notice = f"[task:auth] {_API_KEY_ENV} set from {fallback_name} (prefix={_mask(fallback_value)})"
        print(notice, flush=True)
        return fallback_value
    return None
44
+
45
+
46
def allowed_environment_api_keys() -> Set[str]:
    """Collect every API key this Task App accepts.

    The set contains:
    - the primary key from :func:`normalize_environment_api_key`, which may
      first copy a dev fallback into ``ENVIRONMENT_API_KEY``;
    - each non-blank entry of the comma-separated
      ``ENVIRONMENT_API_KEY_ALIASES`` variable.

    An empty set means no key is configured.
    """
    primary = normalize_environment_api_key()
    alias_blob = (os.getenv(_API_KEY_ALIASES_ENV) or "").strip()
    accepted: set[str] = {primary} if primary else set()
    if alias_blob:
        stripped = (chunk.strip() for chunk in alias_blob.split(","))
        accepted.update(entry for entry in stripped if entry)
    return accepted
64
+
65
+
66
def _header_values(request: Any, header: str) -> Iterable[str]:
    """Look up ``header`` on a request-like object.

    Returns a one-element list with the first value found, or an empty list.
    The lookup tries, in order:
    1. a truthy ``request.headers`` object;
    2. ``request`` itself when it is a dict;
    3. ``request.headers`` or ``request.state`` when either is a plain dict.

    Each container is queried with the name as given and then lower-cased.
    """
    lowered = header.lower()
    if request is None:
        return []

    def _containers() -> Iterable[Any]:
        # Generated lazily so later attributes are only touched when needed.
        primary = getattr(request, "headers", None)
        if primary:
            yield primary
        if isinstance(request, dict):
            yield request
        for attr_name in ("headers", "state"):
            candidate = getattr(request, attr_name, None)
            if isinstance(candidate, dict):
                yield candidate

    for container in _containers():
        found = container.get(header) or container.get(lowered)
        if found is not None:
            return [found]
    return []
87
+
88
+
89
def _split_csv(values: Iterable[str]) -> list[str]:
    """Flatten comma-separated strings into trimmed, non-empty tokens.

    Non-string items are ignored. Order is preserved and duplicates are kept.
    """
    return [
        token
        for item in values
        if isinstance(item, str)
        for token in (piece.strip() for piece in item.split(","))
        if token
    ]
99
+
100
+
101
def is_api_key_header_authorized(request: Any) -> bool:
    """Report whether the request presents any allowed environment key.

    Candidate keys are gathered from three places:
    - ``X-API-Key``;
    - ``X-API-Keys``, split on commas;
    - an ``Authorization: Bearer <key>`` header (scheme matched
      case-insensitively).

    Returns ``False`` outright when no key is configured.
    """
    accepted = allowed_environment_api_keys()
    if not accepted:
        return False
    direct = list(_header_values(request, _API_KEY_HEADER))
    listed = list(_header_values(request, _API_KEYS_HEADER))
    bearer_tokens = [
        raw.split(" ", 1)[1].strip()
        for raw in _header_values(request, _AUTH_HEADER)
        if isinstance(raw, str) and raw.lower().startswith("bearer ")
    ]
    presented = _split_csv(direct + listed + bearer_tokens)
    return not accepted.isdisjoint(presented)
116
+
117
+
118
def require_api_key_dependency(request: Any) -> None:
    """FastAPI dependency enforcing Task App authentication headers.

    A request is accepted when any key it presents matches one of
    :func:`allowed_environment_api_keys`. Keys are read from:
    - ``X-API-Key``;
    - ``X-API-Keys`` (comma-separated);
    - ``Authorization: Bearer <key>``.

    Raises:
        HTTPException: 503 ``missing_environment_api_key`` when no key is
            configured on the server; 400 ``unauthorised`` when none of the
            presented keys is allowed.
    """
    allowed = allowed_environment_api_keys()
    if not allowed:
        raise http_exception(503, "missing_environment_api_key", "ENVIRONMENT_API_KEY is not configured")

    # Build the candidate list (also reused for diagnostics below).
    single = list(_header_values(request, _API_KEY_HEADER))
    multi = list(_header_values(request, _API_KEYS_HEADER))
    auths = list(_header_values(request, _AUTH_HEADER))
    bearer = [
        a.split(" ", 1)[1].strip()
        for a in auths
        if isinstance(a, str) and a.lower().startswith("bearer ")
    ]
    candidates = _split_csv(single + multi + bearer)
    if any(candidate in allowed for candidate in candidates):
        return

    # SECURITY: never echo configured secrets. The HTTP response goes to an
    # unauthenticated caller, so it reports only how many keys are allowed
    # plus what the caller itself sent. Stdout diagnostics show a masked
    # 4-character prefix of each allowed key, which is enough to spot a
    # mismatch without leaking the key.
    got_first15 = [c[:15] for c in candidates]
    got_lens = [len(c) for c in candidates]
    try:
        print({
            "task_auth_failed": True,
            "allowed_prefixes": sorted(_mask(k) for k in allowed),
            "allowed_count": len(allowed),
            "got_first15": got_first15,
            "got_lens": got_lens,
            "have_x_api_key": bool(single),
            "have_x_api_keys": bool(multi),
            "have_authorization": bool(auths),
        }, flush=True)
    except Exception:
        # Diagnostics are best-effort; a logging failure must not replace the
        # authentication error raised below.
        pass
    # Use 400 (not 401) to make failures unmistakable during preflight.
    raise http_exception(400, "unauthorised", "API key missing or invalid", extra={
        "allowed_count": len(allowed),
        "got_first15": got_first15,
        "got_lens": got_lens,
    })
synth_ai/task/client.py ADDED
@@ -0,0 +1,165 @@
1
+ from __future__ import annotations
2
+
3
+ """Async HTTP client for interacting with Task Apps."""
4
+
5
+ import asyncio
6
+ from typing import Any, Dict, Iterable, List, Optional
7
+ import os
8
+
9
+ import httpx
10
+ from pydantic import BaseModel
11
+
12
+ from .contracts import RolloutRequest, RolloutResponse, TaskInfo
13
+ from .json import to_jsonable
14
+
15
+
16
def _prepare_payload(payload: Any) -> Any:
    """Convert a request body into JSON-ready data.

    - ``None`` passes through unchanged (no body).
    - Pydantic models are dumped in JSON mode using field aliases.
    - Anything else goes through ``to_jsonable``.
    """
    if isinstance(payload, BaseModel):
        return payload.model_dump(mode="json", by_alias=True)
    return None if payload is None else to_jsonable(payload)
22
+
23
+
24
+ class TaskAppClient:
25
+ def __init__(
26
+ self,
27
+ base_url: str,
28
+ api_key: str | None = None,
29
+ *,
30
+ timeout: float = 600.0,
31
+ retries: int = 3,
32
+ ) -> None:
33
+ self.base_url = base_url.rstrip("/")
34
+ self.api_key = api_key
35
+ self.timeout = timeout
36
+ self.retries = max(1, retries)
37
+ self._client: httpx.AsyncClient | None = None
38
+ self.env = _TaskAppEnvironmentClient(self)
39
+
40
+ async def __aenter__(self) -> "TaskAppClient":
41
+ await self._ensure_client()
42
+ return self
43
+
44
+ async def __aexit__(self, exc_type, exc, tb) -> None:
45
+ await self.aclose()
46
+
47
+ async def _ensure_client(self) -> httpx.AsyncClient:
48
+ if self._client is None:
49
+ self._client = httpx.AsyncClient(
50
+ base_url=self.base_url,
51
+ timeout=httpx.Timeout(self.timeout),
52
+ follow_redirects=True,
53
+ )
54
+ return self._client
55
+
56
+ def _headers(self) -> Dict[str, str]:
57
+ headers: Dict[str, str] = {}
58
+ # Primary key
59
+ primary = (self.api_key or "").strip()
60
+ if primary:
61
+ headers["X-API-Key"] = primary
62
+ # Also set Authorization for clients that read bearer tokens
63
+ headers.setdefault("Authorization", f"Bearer {primary}")
64
+ # Include ALL available environment keys via CSV in X-API-Keys
65
+ keys: list[str] = []
66
+ if primary:
67
+ keys.append(primary)
68
+ aliases = (os.getenv("ENVIRONMENT_API_KEY_ALIASES") or "").strip()
69
+ if aliases:
70
+ for part in aliases.split(","):
71
+ trimmed = part.strip()
72
+ if trimmed and trimmed not in keys:
73
+ keys.append(trimmed)
74
+ if keys:
75
+ headers["X-API-Keys"] = ",".join(keys)
76
+ return headers
77
+
78
+ async def aclose(self) -> None:
79
+ if self._client is not None:
80
+ await self._client.aclose()
81
+ self._client = None
82
+
83
+ async def _request(
84
+ self,
85
+ method: str,
86
+ path: str,
87
+ *,
88
+ params: Optional[Dict[str, Any] | List[tuple[str, Any]]] = None,
89
+ json_payload: Any = None,
90
+ ) -> httpx.Response:
91
+ client = await self._ensure_client()
92
+ payload = _prepare_payload(json_payload)
93
+ headers = self._headers()
94
+ last_exc: Exception | None = None
95
+ for attempt in range(self.retries):
96
+ try:
97
+ response = await client.request(
98
+ method,
99
+ path,
100
+ headers=headers,
101
+ params=params,
102
+ json=payload,
103
+ )
104
+ response.raise_for_status()
105
+ return response
106
+ except httpx.HTTPStatusError as exc:
107
+ if 500 <= exc.response.status_code < 600 and attempt + 1 < self.retries:
108
+ await asyncio.sleep(0.1 * (attempt + 1))
109
+ last_exc = exc
110
+ continue
111
+ raise
112
+ except httpx.HTTPError as exc:
113
+ last_exc = exc
114
+ if attempt + 1 >= self.retries:
115
+ raise
116
+ await asyncio.sleep(0.1 * (attempt + 1))
117
+ if last_exc: # pragma: no cover - defensive
118
+ raise last_exc
119
+ raise RuntimeError("Unreachable code in TaskAppClient._request")
120
+
121
+ async def health(self) -> Dict[str, Any]:
122
+ response = await self._request("GET", "/health")
123
+ return response.json()
124
+
125
+ async def info(self) -> Dict[str, Any]:
126
+ response = await self._request("GET", "/info")
127
+ return response.json()
128
+
129
+ async def task_info(self, seeds: list[int] | None = None) -> TaskInfo | list[TaskInfo]:
130
+ params: Optional[List[tuple[str, Any]]] = None
131
+ if seeds:
132
+ params = [("seed", seed) for seed in seeds]
133
+ response = await self._request("GET", "/task_info", params=params)
134
+ data = response.json()
135
+ if isinstance(data, list):
136
+ return [TaskInfo.model_validate(item) for item in data]
137
+ return TaskInfo.model_validate(data)
138
+
139
+ async def rollout(self, request: RolloutRequest) -> RolloutResponse:
140
+ response = await self._request("POST", "/rollout", json_payload=request)
141
+ data = response.json()
142
+ return RolloutResponse.model_validate(data)
143
+
144
+
145
class _TaskAppEnvironmentClient:
    """Environment lifecycle calls routed through a :class:`TaskAppClient`.

    Each method POSTs to ``/env/{env_name}/<action>`` and returns the decoded
    JSON body.
    """

    def __init__(self, client: TaskAppClient) -> None:
        self._client = client

    async def _post(self, env_name: str, action: str, payload: Any) -> Dict[str, Any]:
        """POST ``payload`` to ``/env/{env_name}/{action}`` and decode the reply."""
        response = await self._client._request(
            "POST", f"/env/{env_name}/{action}", json_payload=payload
        )
        return response.json()

    async def initialize(self, env_name: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        return await self._post(env_name, "initialize", payload)

    async def step(self, env_name: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        return await self._post(env_name, "step", payload)

    async def terminate(self, env_name: str, payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # A missing or empty payload is sent as an empty JSON object.
        return await self._post(env_name, "terminate", payload or {})
@@ -1,31 +1,27 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Optional, Any, Dict, List
5
- from pydantic import BaseModel
4
+ from typing import Optional, Any, Dict, List, Literal
5
+ from pydantic import BaseModel, Field
6
6
 
7
7
 
8
8
  @dataclass(frozen=True)
9
9
  class TaskAppEndpoints:
10
10
  """Canonical Task App endpoint shapes used by RL trainers.
11
11
 
12
- The Task App is an HTTP service (often deployed on Modal) that exposes:
13
- - Health: GET /health
14
- Requires header X-API-Key (when ENVIRONMENT_API_KEY is configured)
15
- Returns { healthy: true }
16
- - Environment lifecycle:
17
- • POST /env/{env_name}/initialize → { env_id, observation }
18
- • POST /env/{env_name}/step → { observation, reward, done, info }
19
- • POST /env/{env_name}/terminate → { ok: true }
20
- - Rollout (optional, unified schema):
21
- • POST /rollout → { run_id, trajectories[], metrics, ... }
22
- - Proxy (optional):
23
- • POST /proxy/v1/chat/completions (for direct OpenAI calls from Task App)
12
+ Task Apps run as lightweight HTTP services (often on Modal) that expose a
13
+ consistent set of endpoints for health, metadata, environment lifecycle,
14
+ rollouts, and optional proxy access to vendor models. The endpoint strings
15
+ defined here act as defaults and documentation for clients.
24
16
  """
25
17
 
18
+ root: str = "/"
26
19
  health: str = "/health"
20
+ info: str = "/info"
21
+ task_info: str = "/task_info"
27
22
  rollout: str = "/rollout"
28
23
  proxy_chat_completions: str = "/proxy/v1/chat/completions"
24
+ proxy_groq_chat_completions: str = "/proxy/groq/v1/chat/completions"
29
25
  env_initialize: str = "/env/{env_name}/initialize"
30
26
  env_step: str = "/env/{env_name}/step"
31
27
  env_terminate: str = "/env/{env_name}/terminate"
@@ -67,6 +63,8 @@ class RolloutRecordConfig(BaseModel):
67
63
  trajectories: bool = True
68
64
  logprobs: bool = False
69
65
  value: bool = False
66
+ return_trace: bool = False
67
+ trace_format: Literal["compact", "full"] = "compact"
70
68
 
71
69
 
72
70
  class RolloutSafetyConfig(BaseModel):
@@ -108,6 +106,9 @@ class RolloutMetrics(BaseModel):
108
106
  mean_return: float
109
107
  num_steps: int
110
108
  num_episodes: int = 0
109
+ outcome_score: Optional[float] = None
110
+ events_score: Optional[float] = None
111
+ details: Dict[str, Any] = Field(default_factory=dict)
111
112
 
112
113
 
113
114
  class RolloutResponse(BaseModel):
@@ -117,4 +118,18 @@ class RolloutResponse(BaseModel):
117
118
  metrics: RolloutMetrics
118
119
  aborted: bool = False
119
120
  ops_executed: int = 0
121
+ trace: Dict[str, Any] | None = None
120
122
 
123
+
124
class TaskInfo(BaseModel):
    """Static metadata describing the capabilities of a Task App task.

    Every section is a free-form JSON object, except ``environments``, which
    is a list of strings (presumably environment names — confirm with the
    Task App implementations).
    """

    task: dict[str, Any]
    environments: list[str]
    action_space: dict[str, Any]
    observation: dict[str, Any]
    dataset: dict[str, Any]
    rubric: dict[str, Any]
    inference: dict[str, Any]
    capabilities: dict[str, Any]
    limits: dict[str, Any]
synth_ai/task/datasets.py ADDED
@@ -0,0 +1,105 @@
1
+ from __future__ import annotations
2
+
3
+ """Dataset registry and helpers shared by Task Apps."""
4
+
5
+ from typing import Any, Callable, Dict, Hashable, Tuple
6
+
7
+ from pydantic import BaseModel, Field, field_validator
8
+
9
+
10
class TaskDatasetSpec(BaseModel):
    """Metadata a Task App publishes about one of its datasets.

    When ``default_split`` is set, it must name one of ``splits``.
    """

    id: str
    name: str
    version: str | None = None
    splits: list[str] = Field(default_factory=list)
    default_split: str | None = None
    cardinality: int | None = None
    description: str | None = None

    @field_validator("default_split")
    @classmethod
    def _validate_default_split(cls, value: str | None, info):
        # `splits` is declared before `default_split`, so (per pydantic v2's
        # ValidationInfo.data semantics) its validated value should already
        # be available here.
        if not value:
            return value
        validated = getattr(info, "data", {})  # type: ignore[attr-defined]
        if value not in (validated.get("splits") or []):
            raise ValueError("default_split must be one of splits when provided")
        return value


# A loader receives the dataset spec and returns the materialised dataset.
RegistryLoader = Callable[[TaskDatasetSpec], Any]
31
+
32
+
33
class TaskDatasetRegistry:
    """In-memory catalogue of dataset specs and the callables that load them.

    Loads are memoised per ``(id, version, default_split)`` unless the
    dataset was registered with ``cache=False``.
    """

    def __init__(self) -> None:
        # dataset id -> (spec, loader, caching enabled?)
        self._entries: Dict[str, Tuple[TaskDatasetSpec, RegistryLoader, bool]] = {}
        # (id, version, default_split) -> previously loaded dataset
        self._cache: Dict[Hashable, Any] = {}

    def register(self, spec: TaskDatasetSpec, loader: RegistryLoader, *, cache: bool = True) -> None:
        """Register ``loader`` for ``spec.id``, replacing any earlier entry."""
        self._entries[spec.id] = (spec, loader, cache)

    def _entry(self, dataset_id: str) -> Tuple[TaskDatasetSpec, RegistryLoader, bool]:
        """Return the stored tuple for ``dataset_id`` or raise ``KeyError``."""
        found = self._entries.get(dataset_id)
        if found is None:
            raise KeyError(f"Dataset not registered: {dataset_id}")
        return found

    def describe(self, dataset_id: str) -> TaskDatasetSpec:
        """Return the spec registered under ``dataset_id``."""
        registered_spec, _loader, _cached = self._entry(dataset_id)
        return registered_spec

    def list(self) -> list[TaskDatasetSpec]:
        """Return all registered specs, ordered by when each id was first registered."""
        return [registered_spec for registered_spec, _loader, _cached in self._entries.values()]

    def get(self, spec: TaskDatasetSpec | str) -> Any:
        """Materialise a dataset, reusing a cached load when caching is enabled.

        ``spec`` is either a dataset id, or a spec whose explicitly set
        fields override the registered spec before it is handed to the
        loader.

        Raises:
            KeyError: if the dataset id is not registered.
        """
        if isinstance(spec, str):
            registered_spec, loader, use_cache = self._entry(spec)
            resolved = registered_spec
        else:
            registered_spec, loader, use_cache = self._entry(spec.id)
            overrides = spec.model_dump(exclude_unset=True)
            resolved = registered_spec.model_copy(update=overrides)

        key: Hashable = (resolved.id, resolved.version, resolved.default_split)
        if not use_cache:
            return loader(resolved)
        if key not in self._cache:
            self._cache[key] = loader(resolved)
        return self._cache[key]

    @staticmethod
    def ensure_split(spec: TaskDatasetSpec, split: str | None) -> str:
        """Resolve ``split`` against ``spec.splits`` and return a concrete name.

        Specs without declared splits accept any value. They fall back to
        ``spec.default_split`` and then to ``"default"``.

        Raises:
            ValueError: if ``split`` is omitted and there is no default, or
                if it is not one of the declared splits.
        """
        declared = spec.splits
        if not declared:
            return split or spec.default_split or "default"
        if split is not None:
            if split in declared:
                return split
            raise ValueError(f"Unknown split '{split}' for dataset {spec.id}")
        if spec.default_split:
            return spec.default_split
        raise ValueError(f"split must be provided for dataset {spec.id}")

    @staticmethod
    def normalise_seed(seed: Any, *, cardinality: int | None = None) -> int:
        """Coerce ``seed`` to a non-negative int.

        Negative seeds are mirrored (``-7`` becomes ``7``). When a positive
        ``cardinality`` is given, the result is wrapped into
        ``[0, cardinality)``.

        Raises:
            ValueError: if ``seed`` cannot be converted with ``int()``.
        """
        try:
            as_int = int(seed)
        except Exception as exc:  # pragma: no cover - defensive
            raise ValueError(f"Seed must be convertible to int (got {seed!r})") from exc
        magnitude = abs(as_int)
        return magnitude % cardinality if cardinality and cardinality > 0 else magnitude
synth_ai/task/errors.py ADDED
@@ -0,0 +1,49 @@
1
+ from __future__ import annotations
2
+
3
+ """Error helpers used across Task App implementations."""
4
+
5
+ from typing import Any, Dict, Optional
6
+
7
+ from .json import to_jsonable
8
+
9
+
10
def error_payload(code: str, message: str, *, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Build the standard ``{"error": {"code": ..., "message": ...}}`` envelope.

    Entries from ``extra`` are merged into the inner ``error`` object. They
    can overwrite ``code``/``message`` if they use those keys.
    """
    body: Dict[str, Any] = {"code": code, "message": message}
    body.update(extra or {})
    return {"error": body}
15
+
16
+
17
def http_exception(
    status_code: int,
    code: str,
    message: str,
    *,
    extra: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, str]] = None,
):
    """Build (but do not raise) a FastAPI ``HTTPException`` with the standard error body.

    The exception's ``detail`` is :func:`error_payload` passed through
    :func:`to_jsonable`. Callers are expected to ``raise`` the returned
    object.

    Raises:
        RuntimeError: if FastAPI cannot be imported.
    """
    try:
        from fastapi import HTTPException  # type: ignore
    except Exception as exc:  # pragma: no cover - FastAPI not installed
        raise RuntimeError("fastapi must be installed to raise HTTPException") from exc

    detail = to_jsonable(error_payload(code, message, extra=extra))
    return HTTPException(status_code=status_code, detail=detail, headers=headers)
32
+
33
+
34
def json_error_response(
    status_code: int,
    code: str,
    message: str,
    *,
    extra: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, str]] = None,
):
    """Return a FastAPI ``JSONResponse`` carrying the standard error body.

    The response content is :func:`error_payload` passed through
    :func:`to_jsonable`.

    Raises:
        RuntimeError: if FastAPI cannot be imported.
    """
    try:
        from fastapi.responses import JSONResponse  # type: ignore
    except Exception as exc:  # pragma: no cover - FastAPI not installed
        raise RuntimeError("fastapi must be installed to build JSONResponse") from exc

    content = to_jsonable(error_payload(code, message, extra=extra))
    return JSONResponse(status_code=status_code, content=content, headers=headers)
49
+
synth_ai/task/json.py ADDED
@@ -0,0 +1,77 @@
1
+ from __future__ import annotations
2
+
3
+ """Shared JSON sanitisation helpers for Task Apps."""
4
+
5
+ from collections.abc import Mapping, Sequence
6
+ from dataclasses import is_dataclass, asdict
7
+ from enum import Enum
8
+ from typing import Any
9
+
10
+ try: # numpy is optional at runtime; degrade gracefully if absent
11
+ import numpy as _np # type: ignore
12
+ except Exception: # pragma: no cover - handled at runtime
13
+ _np = None # type: ignore
14
+
15
+
16
def _mask_numpy_array(arr: "_np.ndarray") -> str:
    """Describe an array by shape and dtype instead of serialising its contents."""
    return "<ndarray shape={} dtype={}>".format(
        getattr(arr, "shape", None),
        getattr(arr, "dtype", None),
    )
20
+
21
+
22
def to_jsonable(value: Any) -> Any:
    """Convert `value` into structures compatible with JSON serialisation.

    - numpy scalars are converted to their Python counterparts
    - numpy arrays are represented by a compact descriptor string
    - dataclasses, Enums, and pydantic models are unwrapped recursively
    - sets, frozensets and tuples are converted to lists
    - mapping keys are coerced to ``str``
    - bytes/bytearray are represented by a length descriptor
    - class objects and other non-serialisable objects fall back to `repr`

    Objects exposing ``model_dump``/``dict``/``to_dict``/``to_json`` are
    dumped via the first such method callable without required arguments.
    If none is, conversion falls through to the generic rules below.
    Floats (including ``nan``/``inf``) are returned unchanged.
    """

    if value is None or isinstance(value, (str, bool, int, float)):
        return value

    # Class objects are not data. `is_dataclass` is also True for dataclass
    # *types* (where `asdict` raises TypeError), and `model_dump`/`dict` on a
    # model class are unbound methods that fail without an instance.
    if isinstance(value, type):
        return repr(value)

    # numpy scalars / arrays
    if _np is not None:
        if isinstance(value, _np.integer):
            return int(value)
        if isinstance(value, _np.floating):
            return float(value)
        if isinstance(value, _np.bool_):
            return bool(value)
        if isinstance(value, _np.ndarray):
            return _mask_numpy_array(value)

    if isinstance(value, Enum):
        return to_jsonable(value.value)

    if is_dataclass(value):
        return to_jsonable(asdict(value))

    # pydantic BaseModel / attrs-style objects: use the first dump method that
    # can be called without required arguments.
    for attr in ("model_dump", "dict", "to_dict", "to_json"):
        dump = getattr(value, attr, None)
        if not callable(dump):
            continue
        try:
            dumped = dump()  # type: ignore[misc]
        except TypeError:
            try:
                dumped = dump(exclude_none=False)  # pragma: no cover
            except TypeError:
                # The signature needs arguments we cannot supply; try the
                # next candidate instead of aborting the whole conversion.
                continue
        return to_jsonable(dumped)

    if isinstance(value, Mapping):
        return {str(k): to_jsonable(v) for k, v in value.items()}

    if isinstance(value, (set, frozenset, tuple)):
        return [to_jsonable(v) for v in value]

    if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
        return [to_jsonable(v) for v in value]

    if isinstance(value, (bytes, bytearray)):
        return f"<bytes len={len(value)}>"

    if hasattr(value, "__dict__"):
        return to_jsonable(vars(value))

    return repr(value)
77
+