synth-ai 0.2.8.dev12__py3-none-any.whl → 0.2.8.dev13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (36)
  1. synth_ai/api/train/__init__.py +5 -0
  2. synth_ai/api/train/builders.py +165 -0
  3. synth_ai/api/train/cli.py +429 -0
  4. synth_ai/api/train/config_finder.py +120 -0
  5. synth_ai/api/train/env_resolver.py +302 -0
  6. synth_ai/api/train/pollers.py +66 -0
  7. synth_ai/api/train/task_app.py +128 -0
  8. synth_ai/api/train/utils.py +232 -0
  9. synth_ai/cli/__init__.py +23 -0
  10. synth_ai/cli/rl_demo.py +2 -2
  11. synth_ai/cli/root.py +2 -1
  12. synth_ai/cli/task_apps.py +520 -0
  13. synth_ai/task/__init__.py +94 -1
  14. synth_ai/task/apps/__init__.py +88 -0
  15. synth_ai/task/apps/grpo_crafter.py +438 -0
  16. synth_ai/task/apps/math_single_step.py +852 -0
  17. synth_ai/task/auth.py +132 -0
  18. synth_ai/task/client.py +148 -0
  19. synth_ai/task/contracts.py +29 -14
  20. synth_ai/task/datasets.py +105 -0
  21. synth_ai/task/errors.py +49 -0
  22. synth_ai/task/json.py +77 -0
  23. synth_ai/task/proxy.py +258 -0
  24. synth_ai/task/rubrics.py +212 -0
  25. synth_ai/task/server.py +398 -0
  26. synth_ai/task/tracing_utils.py +79 -0
  27. synth_ai/task/vendors.py +61 -0
  28. synth_ai/tracing_v3/session_tracer.py +13 -5
  29. synth_ai/tracing_v3/storage/base.py +10 -12
  30. synth_ai/tracing_v3/turso/manager.py +20 -6
  31. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/METADATA +3 -2
  32. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/RECORD +36 -14
  33. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/WHEEL +0 -0
  34. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/entry_points.txt +0 -0
  35. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/licenses/LICENSE +0 -0
  36. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/top_level.txt +0 -0
synth_ai/task/auth.py ADDED
@@ -0,0 +1,132 @@
1
+ from __future__ import annotations
2
+
3
+ """Authentication helpers shared by Task Apps."""
4
+
5
+ import os
6
+ from typing import Iterable, Optional, Any
7
+
8
+ from .errors import http_exception
9
+
10
# Environment variable holding the key that Task App requests must present.
_API_KEY_ENV = "ENVIRONMENT_API_KEY"
# Fallback variables tried in order by normalize_environment_api_key() when
# ENVIRONMENT_API_KEY is unset; the first non-empty one is copied into it.
_DEV_API_KEY_ENVS = ("dev_environment_api_key", "DEV_ENVIRONMENT_API_KEY")
# Request headers searched for candidate keys (every value is split on commas).
_API_KEY_HEADER = "x-api-key"
_API_KEYS_HEADER = "x-api-keys"
# Only "Bearer <key>" values of this header are considered.
_AUTH_HEADER = "authorization"
15
+
16
+
17
def _mask(value: str, *, prefix: int = 4) -> str:
    """Render ``value`` safely for logs.

    Keeps only the first ``prefix`` characters and appends an ellipsis when
    anything was cut off. Empty (or otherwise falsy) values render as
    ``<empty>``.
    """
    if not value:
        return "<empty>"
    head = value[:prefix]
    was_truncated = len(value) > prefix
    return head + "…" if was_truncated else head
22
+
23
+
24
def normalize_environment_api_key() -> Optional[str]:
    """Resolve the Task App API key, promoting a dev fallback when needed.

    Returns ``ENVIRONMENT_API_KEY`` when it holds a non-empty value.
    Otherwise the names in ``_DEV_API_KEY_ENVS`` are tried in order. The first
    non-empty one is written into ``os.environ["ENVIRONMENT_API_KEY"]`` (a
    process-wide side effect). A notice showing only a masked prefix is then
    printed to stdout, and that value is returned.
    Returns ``None`` when no key is configured anywhere.
    """
    configured = os.getenv(_API_KEY_ENV)
    if configured:
        return configured
    for source in _DEV_API_KEY_ENVS:
        fallback = os.getenv(source)
        if not fallback:
            continue
        os.environ[_API_KEY_ENV] = fallback
        notice = f"[task:auth] {_API_KEY_ENV} set from {source} (prefix={_mask(fallback)})"
        print(notice, flush=True)
        return fallback
    return None
43
+
44
+
45
def _header_values(request: Any, header: str) -> Iterable[str]:
    """Return the value of ``header`` carried by ``request`` as a 0- or 1-item list.

    ``request`` may be one of:
    - ``None``;
    - an object exposing a ``headers`` mapping (e.g. a FastAPI/Starlette request);
    - a plain ``dict`` of headers.
    An object's ``headers``/``state`` attributes are also consulted when they
    are plain dicts.

    HTTP header names are case-insensitive. Besides the exact and lower-cased
    spellings, plain ``dict`` mappings are therefore scanned
    case-insensitively, so ``{"X-API-Key": "k"}`` matches ``"x-api-key"``.
    """
    if request is None:
        return []
    header_lower = header.lower()

    def _lookup(mapping: Any) -> Any:
        # Fast path: exact / lower-cased key (framework header objects are
        # already case-insensitive).
        raw = mapping.get(header) or mapping.get(header_lower)
        if raw is None and isinstance(mapping, dict):
            # Plain dicts are case-sensitive; fall back to a case-insensitive scan.
            for key, value in mapping.items():
                if isinstance(key, str) and key.lower() == header_lower:
                    return value
        return raw

    headers = getattr(request, "headers", None)
    if headers:
        raw = _lookup(headers)
        if raw is not None:
            return [raw]
    if isinstance(request, dict):
        raw = _lookup(request)
        if raw is not None:
            return [raw]
    # Support passing explicit header dict via keyword arg on FastAPI route handlers
    for attr in ("headers", "state"):
        maybe = getattr(request, attr, None)
        if isinstance(maybe, dict):
            raw = _lookup(maybe)
            if raw is not None:
                return [raw]
    return []
66
+
67
+
68
def _split_csv(values: Iterable[str]) -> list[str]:
    """Flatten comma-separated header values into trimmed, non-empty tokens.

    Non-string items are skipped. Token order and duplicates are preserved.
    """
    return [
        token
        for raw in values
        if isinstance(raw, str)
        for token in (piece.strip() for piece in raw.split(","))
        if token
    ]
78
+
79
+
80
def is_api_key_header_authorized(request: Any) -> bool:
    """Return True if ``request`` carries the configured Task App API key.

    Candidate keys are gathered from three places:
    - the ``x-api-key`` header;
    - the ``x-api-keys`` header;
    - an ``Authorization: Bearer <key>`` header.
    Every header value may hold several comma-separated keys. Returns False
    when no ``ENVIRONMENT_API_KEY`` (or dev fallback) is configured.

    Keys are compared with :func:`hmac.compare_digest`, which is designed to
    avoid leaking through timing how much of the secret a guess matched.
    """
    import hmac

    expected = normalize_environment_api_key()
    if not expected:
        return False
    single = list(_header_values(request, _API_KEY_HEADER))
    multi = list(_header_values(request, _API_KEYS_HEADER))
    bearer = [
        a.split(" ", 1)[1].strip()
        for a in _header_values(request, _AUTH_HEADER)
        if isinstance(a, str) and a.lower().startswith("bearer ")
    ]
    # Encode to bytes: compare_digest rejects non-ASCII str arguments.
    expected_bytes = expected.encode("utf-8")
    return any(
        hmac.compare_digest(candidate.encode("utf-8"), expected_bytes)
        for candidate in _split_csv(single + multi + bearer)
    )
95
+
96
+
97
def require_api_key_dependency(request: Any) -> None:
    """FastAPI dependency enforcing Task App authentication headers.

    Candidate keys are read from three places:
    - the ``x-api-key`` header;
    - the ``x-api-keys`` header;
    - ``Authorization: Bearer <key>``.
    Each value may hold comma-separated keys. Candidates are compared against
    the configured ``ENVIRONMENT_API_KEY`` with :func:`hmac.compare_digest`.

    Raises:
        HTTPException: 503 ``missing_environment_api_key`` when no key is
            configured; 400 ``unauthorised`` when no presented key matches.

    Failure diagnostics never reveal characters of the configured secret.
    The error body returned to the (unauthenticated) caller carries only
    lengths. The stdout log masks every key with ``_mask``.
    """
    import hmac

    expected = normalize_environment_api_key()
    if not expected:
        raise http_exception(503, "missing_environment_api_key", "ENVIRONMENT_API_KEY is not configured")
    single = list(_header_values(request, _API_KEY_HEADER))
    multi = list(_header_values(request, _API_KEYS_HEADER))
    auths = list(_header_values(request, _AUTH_HEADER))
    bearer = [
        a.split(" ", 1)[1].strip()
        for a in auths
        if isinstance(a, str) and a.lower().startswith("bearer ")
    ]
    candidates = _split_csv(single + multi + bearer)
    # Encode to bytes: compare_digest rejects non-ASCII str arguments.
    expected_bytes = expected.encode("utf-8")
    if any(hmac.compare_digest(c.encode("utf-8"), expected_bytes) for c in candidates):
        return

    # Returned to the caller, so it must contain lengths only, never key characters.
    diagnostics = {
        "expected_len": len(expected),
        "got_lens": [len(c) for c in candidates],
    }
    try:
        print({
            "task_auth_failed": True,
            "expected_prefix": _mask(expected),
            "got_prefixes": [_mask(c) for c in candidates],
            **diagnostics,
            "have_x_api_key": bool(single),
            "have_x_api_keys": bool(multi),
            "have_authorization": bool(auths),
        }, flush=True)
    except Exception:
        # Diagnostic logging is best-effort; the auth failure below must still be raised.
        pass
    # Use 400 to make failures unmistakable during preflight
    raise http_exception(400, "unauthorised", "API key missing or invalid", extra=diagnostics)
@@ -0,0 +1,148 @@
1
+ from __future__ import annotations
2
+
3
+ """Async HTTP client for interacting with Task Apps."""
4
+
5
+ import asyncio
6
+ from typing import Any, Dict, Iterable, List, Optional
7
+
8
+ import httpx
9
+ from pydantic import BaseModel
10
+
11
+ from .contracts import RolloutRequest, RolloutResponse, TaskInfo
12
+ from .json import to_jsonable
13
+
14
+
15
def _prepare_payload(payload: Any) -> Any:
    """Turn ``payload`` into a JSON-ready value for ``httpx``.

    ``None`` passes through unchanged. Pydantic models are dumped in JSON
    mode using their field aliases. Anything else is sanitised by
    ``to_jsonable``.
    """
    if isinstance(payload, BaseModel):
        return payload.model_dump(mode="json", by_alias=True)
    return None if payload is None else to_jsonable(payload)
21
+
22
+
23
class TaskAppClient:
    """Async HTTP client for a Task App service.

    Wraps a lazily created ``httpx.AsyncClient`` bound to ``base_url``. When
    ``api_key`` is set, every request carries an ``X-API-Key`` header. The
    client can be used as an async context manager, which closes the
    underlying client on exit. Environment-lifecycle calls live on ``self.env``.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str | None = None,
        *,
        timeout: float = 600.0,
        retries: int = 3,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.timeout = timeout
        # At least one attempt is always made, even if retries <= 0.
        self.retries = max(1, retries)
        self._client: httpx.AsyncClient | None = None
        self.env = _TaskAppEnvironmentClient(self)

    async def __aenter__(self) -> "TaskAppClient":
        await self._ensure_client()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.aclose()

    async def _ensure_client(self) -> httpx.AsyncClient:
        """Return the shared ``httpx.AsyncClient``, creating it on first use."""
        if self._client is not None:
            return self._client
        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=httpx.Timeout(self.timeout),
            follow_redirects=True,
        )
        return self._client

    def _headers(self) -> Dict[str, str]:
        """Per-request headers: just ``X-API-Key`` when a key is configured."""
        return {"X-API-Key": self.api_key} if self.api_key else {}

    async def aclose(self) -> None:
        """Close the underlying HTTP client if one was created."""
        if self._client is None:
            return
        await self._client.aclose()
        self._client = None

    async def _request(
        self,
        method: str,
        path: str,
        *,
        params: Optional[Iterable[tuple[str, Any]] | Dict[str, Any]] = None,
        json_payload: Any = None,
    ) -> httpx.Response:
        """Send one request, retrying transient failures.

        Up to ``self.retries`` attempts are made. Between attempts there is a
        linear back-off of 0.1s, 0.2s, and so on. The following are retried:
        5xx responses and any other ``httpx.HTTPError`` (transport failures).
        Other error statuses propagate immediately from ``raise_for_status``.
        """
        client = await self._ensure_client()
        body = _prepare_payload(json_payload)
        request_headers = self._headers()
        last_error: Exception | None = None
        attempt = 0
        while attempt < self.retries:
            try:
                response = await client.request(
                    method,
                    path,
                    headers=request_headers,
                    params=params,
                    json=body,
                )
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                server_side = 500 <= exc.response.status_code < 600
                if not server_side or attempt + 1 >= self.retries:
                    raise
                last_error = exc
            except httpx.HTTPError as exc:
                last_error = exc
                if attempt + 1 >= self.retries:
                    raise
            else:
                return response
            attempt += 1
            await asyncio.sleep(0.1 * attempt)
        if last_error:  # pragma: no cover - defensive
            raise last_error
        raise RuntimeError("Unreachable code in TaskAppClient._request")

    async def health(self) -> Dict[str, Any]:
        """GET ``/health`` and return the decoded JSON body."""
        return (await self._request("GET", "/health")).json()

    async def info(self) -> Dict[str, Any]:
        """GET ``/info`` and return the decoded JSON body."""
        return (await self._request("GET", "/info")).json()

    async def task_info(self, seeds: list[int] | None = None) -> TaskInfo | list[TaskInfo]:
        """GET ``/task_info``; each seed is sent as a repeated ``seed`` query param.

        Returns a list of ``TaskInfo`` when the server answers with a JSON
        array, otherwise a single ``TaskInfo``.
        """
        query: Optional[List[tuple[str, Any]]] = (
            [("seed", s) for s in seeds] if seeds else None
        )
        body = (await self._request("GET", "/task_info", params=query)).json()
        if not isinstance(body, list):
            return TaskInfo.model_validate(body)
        return [TaskInfo.model_validate(entry) for entry in body]

    async def rollout(self, request: RolloutRequest) -> RolloutResponse:
        """POST ``request`` to ``/rollout`` and parse the reply."""
        response = await self._request("POST", "/rollout", json_payload=request)
        return RolloutResponse.model_validate(response.json())
126
+
127
+
128
class _TaskAppEnvironmentClient:
    """Environment-lifecycle endpoints (``/env/{env_name}/...``) of a Task App.

    Every call is delegated to the owning ``TaskAppClient``'s ``_request``
    and returns the decoded JSON reply.
    """

    def __init__(self, client: TaskAppClient) -> None:
        self._client = client

    async def _post(self, env_name: str, action: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` to ``/env/{env_name}/{action}`` and decode the reply."""
        response = await self._client._request(
            "POST", f"/env/{env_name}/{action}", json_payload=payload
        )
        return response.json()

    async def initialize(self, env_name: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        return await self._post(env_name, "initialize", payload)

    async def step(self, env_name: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        return await self._post(env_name, "step", payload)

    async def terminate(self, env_name: str, payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # A missing (or empty) payload is sent as an empty JSON object.
        return await self._post(env_name, "terminate", payload or {})
@@ -1,31 +1,27 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Optional, Any, Dict, List
5
- from pydantic import BaseModel
4
+ from typing import Optional, Any, Dict, List, Literal
5
+ from pydantic import BaseModel, Field
6
6
 
7
7
 
8
8
  @dataclass(frozen=True)
9
9
  class TaskAppEndpoints:
10
10
  """Canonical Task App endpoint shapes used by RL trainers.
11
11
 
12
- The Task App is an HTTP service (often deployed on Modal) that exposes:
13
- - Health: GET /health
14
- Requires header X-API-Key (when ENVIRONMENT_API_KEY is configured)
15
- Returns { healthy: true }
16
- - Environment lifecycle:
17
- • POST /env/{env_name}/initialize → { env_id, observation }
18
- • POST /env/{env_name}/step → { observation, reward, done, info }
19
- • POST /env/{env_name}/terminate → { ok: true }
20
- - Rollout (optional, unified schema):
21
- • POST /rollout → { run_id, trajectories[], metrics, ... }
22
- - Proxy (optional):
23
- • POST /proxy/v1/chat/completions (for direct OpenAI calls from Task App)
12
+ Task Apps run as lightweight HTTP services (often on Modal) that expose a
13
+ consistent set of endpoints for health, metadata, environment lifecycle,
14
+ rollouts, and optional proxy access to vendor models. The endpoint strings
15
+ defined here act as defaults and documentation for clients.
24
16
  """
25
17
 
18
+ root: str = "/"
26
19
  health: str = "/health"
20
+ info: str = "/info"
21
+ task_info: str = "/task_info"
27
22
  rollout: str = "/rollout"
28
23
  proxy_chat_completions: str = "/proxy/v1/chat/completions"
24
+ proxy_groq_chat_completions: str = "/proxy/groq/v1/chat/completions"
29
25
  env_initialize: str = "/env/{env_name}/initialize"
30
26
  env_step: str = "/env/{env_name}/step"
31
27
  env_terminate: str = "/env/{env_name}/terminate"
@@ -67,6 +63,8 @@ class RolloutRecordConfig(BaseModel):
67
63
  trajectories: bool = True
68
64
  logprobs: bool = False
69
65
  value: bool = False
66
+ return_trace: bool = False
67
+ trace_format: Literal["compact", "full"] = "compact"
70
68
 
71
69
 
72
70
  class RolloutSafetyConfig(BaseModel):
@@ -108,6 +106,9 @@ class RolloutMetrics(BaseModel):
108
106
  mean_return: float
109
107
  num_steps: int
110
108
  num_episodes: int = 0
109
+ outcome_score: Optional[float] = None
110
+ events_score: Optional[float] = None
111
+ details: Dict[str, Any] = Field(default_factory=dict)
111
112
 
112
113
 
113
114
  class RolloutResponse(BaseModel):
@@ -117,4 +118,18 @@ class RolloutResponse(BaseModel):
117
118
  metrics: RolloutMetrics
118
119
  aborted: bool = False
119
120
  ops_executed: int = 0
121
+ trace: Dict[str, Any] | None = None
120
122
 
123
+
124
class TaskInfo(BaseModel):
    """Static metadata describing the capabilities of a Task App task.

    All fields are required (none has a default). Apart from
    ``environments`` (a list of strings), each is a free-form JSON object
    whose keys this model does not constrain.
    ``TaskAppClient.task_info`` validates ``/task_info`` responses into this
    model.
    """

    task: Dict[str, Any]
    # presumably the env names usable with /env/{env_name}/... — TODO confirm
    environments: List[str]
    action_space: Dict[str, Any]
    observation: Dict[str, Any]
    dataset: Dict[str, Any]
    rubric: Dict[str, Any]
    inference: Dict[str, Any]
    capabilities: Dict[str, Any]
    limits: Dict[str, Any]
@@ -0,0 +1,105 @@
1
+ from __future__ import annotations
2
+
3
+ """Dataset registry and helpers shared by Task Apps."""
4
+
5
+ from typing import Any, Callable, Dict, Hashable, Tuple
6
+
7
+ from pydantic import BaseModel, Field, field_validator
8
+
9
+
10
class TaskDatasetSpec(BaseModel):
    """Metadata describing one dataset exposed by a Task App.

    When ``default_split`` is given (non-empty), it must name one of
    ``splits``.
    """

    id: str
    name: str
    version: str | None = None
    splits: list[str] = Field(default_factory=list)
    default_split: str | None = None
    cardinality: int | None = None
    description: str | None = None

    @field_validator("default_split")
    @classmethod
    def _validate_default_split(cls, value: str | None, info):
        # Empty / missing defaults need no cross-check.
        if not value:
            return value
        # ``info.data`` holds the fields validated before this one; ``splits``
        # is declared earlier, so it is available here.
        validated = info.data if hasattr(info, "data") else {}  # type: ignore[attr-defined]
        if value not in (validated.get("splits") or []):
            raise ValueError("default_split must be one of splits when provided")
        return value
28
+
29
+
30
# Loader signature: TaskDatasetRegistry.get() calls it with the resolved spec
# and returns (and, when caching is enabled, memoises) whatever it produces.
RegistryLoader = Callable[[TaskDatasetSpec], Any]
31
+
32
+
33
class TaskDatasetRegistry:
    """Registry mapping dataset ids to their spec plus a loader callable.

    For entries registered with ``cache=True``, loader results are memoised
    per ``(id, version, default_split)`` of the resolved spec.
    """

    def __init__(self) -> None:
        # dataset id -> (registered spec, loader, caching enabled?)
        self._entries: Dict[str, Tuple[TaskDatasetSpec, RegistryLoader, bool]] = {}
        # (id, version, default_split) -> materialised dataset
        self._cache: Dict[Hashable, Any] = {}

    def register(self, spec: TaskDatasetSpec, loader: RegistryLoader, *, cache: bool = True) -> None:
        """Register (or replace) the loader and metadata stored under ``spec.id``."""
        self._entries[spec.id] = (spec, loader, cache)

    def _entry(self, dataset_id: str) -> Tuple[TaskDatasetSpec, RegistryLoader, bool]:
        """Return the stored entry for ``dataset_id``; ``KeyError`` if unknown."""
        entry = self._entries.get(dataset_id)
        if entry is None:
            raise KeyError(f"Dataset not registered: {dataset_id}")
        return entry

    def describe(self, dataset_id: str) -> TaskDatasetSpec:
        """Return the spec registered under ``dataset_id``; ``KeyError`` if unknown."""
        return self._entry(dataset_id)[0]

    def list(self) -> list[TaskDatasetSpec]:
        """Return every registered spec, ordered by first registration of its id."""
        return [registered for registered, _loader, _cached in self._entries.values()]

    def get(self, spec: TaskDatasetSpec | str) -> Any:
        """Materialise a dataset, using the cache when the entry enables it.

        ``spec`` may be a dataset id or a spec object. In the latter case,
        fields explicitly set on it override the registered spec's values.
        """
        if isinstance(spec, str):
            registered, loader, use_cache = self._entry(spec)
            resolved = registered
        else:
            registered, loader, use_cache = self._entry(spec.id)
            resolved = registered.model_copy(update=spec.model_dump(exclude_unset=True))

        key: Hashable = (resolved.id, resolved.version, resolved.default_split)
        if not use_cache:
            return loader(resolved)
        if key not in self._cache:
            self._cache[key] = loader(resolved)
        return self._cache[key]

    @staticmethod
    def ensure_split(spec: TaskDatasetSpec, split: str | None) -> str:
        """Resolve ``split`` against ``spec`` and return a concrete split name.

        With no declared splits, any requested split is accepted. The fallback
        is ``default_split`` and then ``"default"``. With declared splits,
        ``split`` must be one of them. When ``split`` is ``None``, the spec's
        ``default_split`` is used. ``ValueError`` is raised otherwise.
        """
        if not spec.splits:
            return split or spec.default_split or "default"
        if split is not None:
            if split in spec.splits:
                return split
            raise ValueError(f"Unknown split '{split}' for dataset {spec.id}")
        if spec.default_split:
            return spec.default_split
        raise ValueError(f"split must be provided for dataset {spec.id}")

    @staticmethod
    def normalise_seed(seed: Any, *, cardinality: int | None = None) -> int:
        """Coerce ``seed`` to a non-negative int.

        The result is reduced modulo ``cardinality`` when that is positive.
        ``ValueError`` is raised if ``int(seed)`` fails.
        """
        try:
            value = int(seed)
        except Exception as exc:  # pragma: no cover - defensive
            raise ValueError(f"Seed must be convertible to int (got {seed!r})") from exc
        value = abs(value)
        if cardinality and cardinality > 0:
            value %= cardinality
        return value
@@ -0,0 +1,49 @@
1
+ from __future__ import annotations
2
+
3
+ """Error helpers used across Task App implementations."""
4
+
5
+ from typing import Any, Dict, Optional
6
+
7
+ from .json import to_jsonable
8
+
9
+
10
def error_payload(code: str, message: str, *, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Build the standard error envelope ``{"error": {"code": ..., "message": ...}}``.

    Keys from ``extra`` are merged into the inner ``error`` object; they may
    override ``code`` or ``message``.
    """
    error: Dict[str, Any] = {"code": code, "message": message}
    if extra:
        error.update(extra)
    return {"error": error}
15
+
16
+
17
def http_exception(
    status_code: int,
    code: str,
    message: str,
    *,
    extra: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, str]] = None,
):
    """Build (but do not raise) a FastAPI ``HTTPException`` with the standard envelope.

    ``detail`` is ``error_payload(code, message, extra=extra)`` passed
    through ``to_jsonable``.

    Raises:
        RuntimeError: if FastAPI cannot be imported.
    """
    try:
        from fastapi import HTTPException  # type: ignore
    except Exception as exc:  # pragma: no cover - FastAPI not installed
        raise RuntimeError("fastapi must be installed to raise HTTPException") from exc

    detail = to_jsonable(error_payload(code, message, extra=extra))
    return HTTPException(status_code=status_code, detail=detail, headers=headers)
32
+
33
+
34
def json_error_response(
    status_code: int,
    code: str,
    message: str,
    *,
    extra: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, str]] = None,
):
    """Build a FastAPI ``JSONResponse`` whose body is the standard error envelope.

    The body is ``error_payload(code, message, extra=extra)`` passed through
    ``to_jsonable``.

    Raises:
        RuntimeError: if FastAPI cannot be imported.
    """
    try:
        from fastapi.responses import JSONResponse  # type: ignore
    except Exception as exc:  # pragma: no cover - FastAPI not installed
        raise RuntimeError("fastapi must be installed to build JSONResponse") from exc

    body = to_jsonable(error_payload(code, message, extra=extra))
    return JSONResponse(status_code=status_code, content=body, headers=headers)
49
+
synth_ai/task/json.py ADDED
@@ -0,0 +1,77 @@
1
+ from __future__ import annotations
2
+
3
+ """Shared JSON sanitisation helpers for Task Apps."""
4
+
5
+ from collections.abc import Mapping, Sequence
6
+ from dataclasses import is_dataclass, asdict
7
+ from enum import Enum
8
+ from typing import Any
9
+
10
+ try: # numpy is optional at runtime; degrade gracefully if absent
11
+ import numpy as _np # type: ignore
12
+ except Exception: # pragma: no cover - handled at runtime
13
+ _np = None # type: ignore
14
+
15
+
16
def _mask_numpy_array(arr: "_np.ndarray") -> str:
    """Describe an array by shape and dtype instead of serialising its contents.

    Missing attributes are rendered as ``None``.
    """
    return "<ndarray shape={} dtype={}>".format(
        getattr(arr, "shape", None),
        getattr(arr, "dtype", None),
    )
20
+
21
+
22
def to_jsonable(value: Any) -> Any:
    """Convert `value` into structures compatible with JSON serialisation.

    - ``None``, ``str``, ``bool``, ``int`` and ``float`` are returned unchanged
    - numpy scalars are converted to their Python counterparts
    - numpy arrays are represented by a compact descriptor string
    - classes (``type`` objects) are represented by their ``repr``
    - dataclass instances, Enums, and objects exposing ``model_dump`` /
      ``dict`` / ``to_dict`` / ``to_json`` are unwrapped recursively
    - mappings become dicts with ``str`` keys
    - sets, frozensets, tuples and other sequences are converted to lists
    - bytes-like values are summarised as ``<bytes len=N>``
    - other objects are serialised from ``vars()`` when they have a
      ``__dict__``, otherwise they fall back to `repr`
    """

    if value is None or isinstance(value, (str, bool, int, float)):
        return value

    # numpy scalars / arrays
    if _np is not None:
        if isinstance(value, (_np.integer,)):
            return int(value)
        if isinstance(value, (_np.floating,)):
            return float(value)
        if isinstance(value, (_np.bool_,)):
            return bool(value)
        if isinstance(value, (_np.ndarray,)):
            return _mask_numpy_array(value)

    # Classes are not instances. ``asdict`` raises TypeError on a dataclass
    # *type*, and unbound ``model_dump``/``dict`` lack ``self``, so describe
    # them instead of crashing.
    if isinstance(value, type):
        return repr(value)

    if isinstance(value, Enum):
        return to_jsonable(value.value)

    if is_dataclass(value):
        return to_jsonable(asdict(value))

    # pydantic BaseModel / attrs objects
    for attr in ("model_dump", "dict", "to_dict", "to_json"):
        method = getattr(value, attr, None)
        if callable(method):
            try:
                dumped = method()  # type: ignore[misc]
            except TypeError:
                dumped = method(exclude_none=False)  # pragma: no cover
            return to_jsonable(dumped)

    if isinstance(value, Mapping):
        return {str(k): to_jsonable(v) for k, v in value.items()}

    # frozenset is a set type too; previously it fell through to repr().
    if isinstance(value, (set, frozenset, tuple)):
        return [to_jsonable(v) for v in value]

    if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
        return [to_jsonable(v) for v in value]

    if isinstance(value, (bytes, bytearray)):
        return f"<bytes len={len(value)}>"

    if hasattr(value, "__dict__"):
        return to_jsonable(vars(value))

    return repr(value)
77
+