synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,519 @@
|
|
|
1
|
+
"""Shared helper utilities for LocalAPI task apps."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import contextlib
|
|
6
|
+
import inspect
|
|
7
|
+
import os
|
|
8
|
+
import socket
|
|
9
|
+
from collections.abc import Callable, Sequence
|
|
10
|
+
from typing import Any
|
|
11
|
+
from urllib.parse import urlparse, urlunparse
|
|
12
|
+
|
|
13
|
+
from fastapi import FastAPI, HTTPException, Request
|
|
14
|
+
from fastapi.responses import JSONResponse
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def normalize_chat_completion_url(url: str) -> str:
    """Return *url* with a ``/chat/completions`` path appended when missing.

    Already-normalized URLs come back untouched (minus trailing slashes);
    empty input yields the bare ``/chat/completions`` path.  Query string
    and fragment are preserved.
    """
    trimmed = (url or "").rstrip("/")
    if not trimmed:
        return "/chat/completions"

    parts = urlparse(trimmed)
    base_path = parts.path.rstrip("/")

    # Nothing to do when the chat-completions suffix is already present.
    if base_path.endswith(("/v1/chat/completions", "/chat/completions")):
        return trimmed

    def _rebuild(candidate: str) -> str:
        return urlunparse(
            (parts.scheme, parts.netloc, candidate, parts.params, parts.query, parts.fragment)
        )

    # A path routing *through* /v1/ (but not ending at it) just gets the suffix.
    if "/v1/" in base_path and not base_path.endswith("/v1"):
        return _rebuild(f"{base_path}/chat/completions")

    if base_path.endswith("/v1"):
        rebuilt = f"{base_path}/chat/completions"
    elif base_path.endswith("/completions"):
        # Swap a trailing ".../completions" segment for "/chat/completions".
        rebuilt = base_path.rsplit("/", 1)[0] + "/chat/completions"
    else:
        rebuilt = f"{base_path}/v1/chat/completions" if base_path else "/v1/chat/completions"

    return _rebuild(rebuilt)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_default_max_completion_tokens(model_name: str) -> int:
    """Pick a sensible default ``max_completion_tokens`` for *model_name*.

    Matching is case-insensitive substring search; unknown models fall
    back to a conservative 512.
    """
    needle = model_name.lower()
    # Ordered: the first matching family wins (gpt-5 is checked before
    # gpt-4 so the broader pattern never shadows it).
    family_defaults = (
        (("gpt-5", "gpt5"), 2048),
        (("gpt-4", "gpt4"), 4096),
        (("o1", "o3"), 16384),
        (("claude",), 4096),
    )
    for substrings, limit in family_defaults:
        if any(fragment in needle for fragment in substrings):
            return limit
    return 512
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def get_current_module_source() -> str | None:
    """Extract source code for the caller's module using inspect.

    Walks exactly one frame up from this call (``frame.f_back``), resolves
    that frame's module, and returns its full source text.  Returns ``None``
    when any step fails: no frame support in the interpreter, no resolvable
    module, or source unavailable (e.g. REPL, frozen app, C extension).

    NOTE(review): the result is the source of the *direct* caller's module;
    wrapping this call in another helper shifts which module is captured.
    """
    frame = inspect.currentframe()
    try:
        if frame is None:
            # Some interpreters do not expose frame objects.
            return None
        caller_frame = frame.f_back
        if caller_frame is None:
            return None
        module = inspect.getmodule(caller_frame)
        if module is None:
            return None
        try:
            return inspect.getsource(module)
        except (OSError, TypeError, IOError):
            # Source file missing, or module has no retrievable source.
            return None
    finally:
        # Break the frame reference cycle (see the inspect module docs).
        del frame
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def preload_dataset_splits(dataset: Any, splits: Sequence[str], app_name: str) -> None:
    """Warm up *dataset* for the given *splits*, logging progress to stdout.

    Failures are reported (message plus traceback) but never raised, so a
    broken dataset cannot prevent app startup.  Split sizes are included in
    the success message when ``dataset.size`` is available.
    """

    def _emit(message: str) -> None:
        print(message, flush=True)

    _emit(f"[{app_name}] Preloading dataset splits...")
    try:
        dataset.ensure_ready(splits)
        split_sizes: list[Any] = []
        # size() is optional on dataset implementations; ignore any failure.
        with contextlib.suppress(Exception):
            split_sizes = [dataset.size(split) for split in splits]
        if split_sizes:
            _emit(f"[{app_name}] Dataset preloaded successfully: {split_sizes} examples")
        else:
            _emit(f"[{app_name}] Dataset preloaded successfully")
    except Exception as exc:
        _emit(f"[{app_name}] WARNING: Dataset preload failed: {exc}")
        import traceback

        traceback.print_exc()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def create_http_client_hooks(
    timeout: float = 30.0,
    *,
    log_prefix: str | None = None,
    aiohttp_connector_kwargs: dict[str, Any] | None = None,
    httpx_limits: Any | None = None,
) -> tuple[Callable[[Any], Any], Callable[[Any], Any]]:
    """Build startup/shutdown coroutines managing ``app.state.http_client``.

    The startup hook prefers an ``aiohttp.ClientSession``; when aiohttp is
    not installed it falls back to ``httpx.AsyncClient``, and when neither
    can be created it stores ``None``.  The shutdown hook closes whichever
    client was created.  Messages print only when *log_prefix* is set.
    """
    pool_options: dict[str, Any] = {
        "limit": 10,
        "limit_per_host": 5,
        "ttl_dns_cache": 300,
        "use_dns_cache": True,
    }
    if aiohttp_connector_kwargs:
        pool_options.update(aiohttp_connector_kwargs)

    def _log(message: str) -> None:
        if log_prefix:
            print(f"[{log_prefix}] {message}", flush=True)

    async def startup_http_client(app: Any) -> None:
        try:
            import aiohttp

            session_timeout = aiohttp.ClientTimeout(total=timeout)
            tcp_connector = aiohttp.TCPConnector(**pool_options)
            app.state.http_client = aiohttp.ClientSession(timeout=session_timeout, connector=tcp_connector)
            _log("Created app-level aiohttp client session singleton")
        except ImportError:
            # aiohttp not installed: try the httpx fallback.
            try:
                import httpx

                pool_limits = httpx_limits or httpx.Limits(max_keepalive_connections=5, max_connections=10)
                app.state.http_client = httpx.AsyncClient(timeout=timeout, limits=pool_limits)
                _log("Created app-level httpx client singleton (fallback)")
            except Exception as exc:
                _log(f"WARNING: Failed to create http client: {exc}")
                app.state.http_client = None
        except Exception as exc:
            _log(f"WARNING: Failed to create aiohttp client: {exc}")
            app.state.http_client = None

    async def shutdown_http_client(app: Any) -> None:
        client = getattr(app.state, "http_client", None)
        if client is None:
            return
        try:
            if hasattr(client, "close"):
                await client.close()  # aiohttp-style close
            elif hasattr(client, "aclose"):
                await client.aclose()  # httpx-style close
            _log("Closed app-level http client")
        except Exception as exc:
            _log(f"WARNING: Error closing http client: {exc}")

    return startup_http_client, shutdown_http_client
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def extract_api_key(
    request: Request,
    policy_config: dict[str, Any],
    default_env_keys: dict[str, str] | None = None,
) -> str | None:
    """Resolve the API key to use for an inference call.

    Known provider hosts (matched against the configured inference URL)
    are keyed from the environment; otherwise the key comes from the
    request's ``X-API-Key`` header or, failing that, a bearer
    ``Authorization`` header.  Returns ``None`` when no key is found.
    """
    env_key_map = default_env_keys or {
        "api.groq.com": "GROQ_API_KEY",
        "api.openai.com": "OPENAI_API_KEY",
    }

    # First non-empty of inference_url / api_base / base_url decides routing.
    candidates = (
        policy_config.get("inference_url"),
        policy_config.get("api_base"),
        policy_config.get("base_url"),
    )
    route_base = next(
        (str(value).strip() for value in candidates if value and str(value).strip()),
        "",
    )
    route_lower = route_base.lower()
    for host_fragment, env_var in env_key_map.items():
        if host_fragment in route_lower:
            # Provider host: the environment variable wins, even if unset.
            return os.getenv(env_var)

    header_key = request.headers.get("X-API-Key") or request.headers.get("x-api-key")
    if header_key:
        return header_key
    bearer = request.headers.get("Authorization") or request.headers.get("authorization")
    if bearer:
        return bearer.replace("Bearer ", "").strip()
    return None
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def parse_tool_calls_from_response(
    response_json: dict[str, Any],
    expected_tool_name: str | None = None,
) -> list[dict[str, Any]]:
    """Extract normalized tool-call dicts from a chat-completion payload.

    Only the first choice's message is inspected.  Each returned entry has
    ``id``, ``type``, and a ``function`` mapping with ``name``/``arguments``.

    Raises:
        ValueError: if *expected_tool_name* is given and a call names a
            different (non-empty) tool.
    """
    if not isinstance(response_json, dict):
        return []
    choices = response_json.get("choices") or []
    if not choices:
        return []
    message = (choices[0] or {}).get("message", {}) if choices else {}
    raw_calls = message.get("tool_calls", []) or []

    def _normalize(raw: Any) -> dict[str, Any]:
        fn_block = (raw or {}).get("function", {}) or {}
        fn_name = fn_block.get("name", "")
        if expected_tool_name and fn_name and fn_name != expected_tool_name:
            raise ValueError(f"Unexpected tool name: {fn_name}")
        return {
            "id": (raw or {}).get("id", ""),
            "type": (raw or {}).get("type", "function"),
            "function": {
                "name": fn_name,
                "arguments": fn_block.get("arguments", "{}"),
            },
        }

    return [_normalize(call) for call in raw_calls]
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
async def call_chat_completion_api(
    policy_config: dict[str, Any],
    messages: list[dict[str, str]],
    tools: list[dict[str, Any]] | None = None,
    tool_choice: str | None = None,
    api_key: str | None = None,
    http_client: Any | None = None,
    enable_dns_preresolution: bool = True,
    validate_response: bool = True,
    expected_tool_name: str | None = None,
    *,
    default_temperature: float = 0.7,
    log_prefix: str | None = None,
) -> tuple[str, dict[str, Any], list[dict[str, Any]]]:
    """Unified chat completion API caller with common LocalAPI logic.

    Resolves the route from ``policy_config`` (``inference_url`` wins over
    ``api_base``/``base_url``), builds the request payload, optionally
    pre-resolves DNS for non-provider hosts, POSTs via the supplied
    aiohttp/httpx client, and returns
    ``(content_text, raw_response_json, tool_calls)``.

    Fix vs. previous revision: DNS-preresolution routing state was smuggled
    through pseudo-headers (``_use_ip`` / ``_original_host``) that were only
    stripped in the aiohttp branch — the httpx fallback leaked them onto the
    wire as bogus HTTP headers.  The state now lives in local variables so
    neither branch can leak it.

    Raises:
        HTTPException: 400 for missing policy fields, 502 for transport or
            malformed-response failures, or the upstream status code for
            provider errors.
    """
    # --- validate policy config -------------------------------------------
    missing_fields: list[str] = []
    model_val = policy_config.get("model")
    if not isinstance(model_val, str) or not model_val.strip():
        missing_fields.append("model")

    inference_url_raw = policy_config.get("inference_url")
    api_base_raw = policy_config.get("api_base")
    base_url_raw = policy_config.get("base_url")

    if inference_url_raw:
        route_base = str(inference_url_raw).strip()
        if (api_base_raw or base_url_raw) and log_prefix:
            print(
                f"{log_prefix} inference_url is set ({route_base}), ignoring api_base/base_url",
                flush=True,
            )
    else:
        route_base = ((api_base_raw or "").strip()) or ((base_url_raw or "").strip())

    if not route_base:
        missing_fields.append("inference_url")
    if missing_fields:
        raise HTTPException(
            status_code=400,
            detail="Missing policy fields in TOML [prompt_learning.policy]: "
            + ", ".join(missing_fields),
        )

    # --- build request ----------------------------------------------------
    model = policy_config["model"].strip()
    inference_url = normalize_chat_completion_url(str(route_base))
    temperature = policy_config.get("temperature", default_temperature)

    # Explicit token limits win; otherwise pick a per-model default.
    if "max_completion_tokens" in policy_config:
        max_tokens = policy_config.get("max_completion_tokens")
    elif "max_tokens" in policy_config:
        max_tokens = policy_config.get("max_tokens")
    else:
        max_tokens = get_default_max_completion_tokens(model)

    headers: dict[str, str] = {"Content-Type": "application/json"}
    lowered = route_base.lower()
    is_provider_host = ("api.openai.com" in lowered) or ("api.groq.com" in lowered)

    if api_key:
        # Providers expect a bearer token; interceptors expect X-API-Key.
        if is_provider_host:
            headers["Authorization"] = f"Bearer {api_key}"
        else:
            headers["X-API-Key"] = api_key

    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "max_completion_tokens": max_tokens,
    }
    if tools is not None:
        payload["tools"] = tools
    if tool_choice is not None:
        payload["tool_choice"] = tool_choice
    if temperature != 0.0:
        # 0.0 is treated as "use the provider default" and omitted.
        payload["temperature"] = temperature

    if log_prefix:
        with contextlib.suppress(Exception):
            print(f"{log_prefix} POLICY ROUTE -> {inference_url}", flush=True)

    # --- optional DNS pre-resolution (non-provider hosts only) ------------
    # Routing state is kept in locals (NOT headers) so it can never leak
    # onto the wire in either transport branch.
    use_ip = False
    original_host: str | None = None
    if enable_dns_preresolution and not is_provider_host:
        parsed = urlparse(inference_url)
        host = parsed.hostname or ""
        port = parsed.port or (443 if parsed.scheme == "https" else 80)
        with contextlib.suppress(Exception):
            addrinfo = socket.getaddrinfo(host, None, socket.AF_INET)
            ips = sorted({ai[4][0] for ai in addrinfo})
            resolved_ip = ips[0] if ips else None
            if log_prefix:
                print(
                    f"{log_prefix} PROXY_DNS resolved {host} -> {resolved_ip} (from {ips})",
                    flush=True,
                )
            if resolved_ip and parsed.scheme == "https":
                netloc = f"{resolved_ip}:{port}" if port else resolved_ip
                inference_url = f"{parsed.scheme}://{netloc}{parsed.path}"
                if parsed.query:
                    inference_url += f"?{parsed.query}"
                use_ip = True
                original_host = host
                # Real Host header so the server routes the IP-dialed request.
                headers["Host"] = host

    if http_client is None:
        raise HTTPException(status_code=500, detail="HTTP client not initialized (should be created at startup)")

    # --- POST -------------------------------------------------------------
    response_json: dict[str, Any] | None = None
    try:
        is_aiohttp = False
        with contextlib.suppress(Exception):
            import aiohttp

            is_aiohttp = isinstance(http_client, aiohttp.ClientSession)

        if is_aiohttp:
            ssl_setting: Any = None
            if use_ip and original_host:
                import ssl

                # Dialing by IP: disable hostname checks on the context and
                # send the real hostname via SNI (server_hostname) instead.
                ssl_context = ssl.create_default_context()
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
                ssl_setting = ssl_context

            async with http_client.post(
                inference_url,
                json=payload,
                headers=headers,
                ssl=ssl_setting,
                server_hostname=original_host if (use_ip and original_host) else None,
            ) as response:
                status_code = response.status
                if status_code != 200:
                    try:
                        error_json = await response.json()
                        error_message = _extract_error_message(error_json)
                        raise HTTPException(status_code=status_code, detail=f"Interceptor/provider error: {error_message}")
                    except HTTPException:
                        raise
                    except Exception:
                        error_text = (await response.text())[:500]
                        raise HTTPException(
                            status_code=status_code,
                            detail=f"Interceptor/provider returned error: {error_text}",
                        )

                try:
                    response_json = await response.json()
                except Exception:
                    response_text = await response.text()
                    if status_code >= 400:
                        raise HTTPException(status_code=status_code, detail=f"HTTP error: {response_text[:200]}")
                    response_json = {}
        else:
            # httpx-style client; headers contain only real HTTP headers.
            response = await http_client.post(inference_url, json=payload, headers=headers)
            status_code = response.status_code
            if status_code != 200:
                try:
                    error_json = response.json()
                    error_message = _extract_error_message(error_json)
                    raise HTTPException(status_code=status_code, detail=f"Interceptor/provider error: {error_message}")
                except HTTPException:
                    raise
                except Exception:
                    error_text = response.text[:500] if hasattr(response, "text") else "Unknown error"
                    raise HTTPException(
                        status_code=status_code,
                        detail=f"Interceptor/provider returned error: {error_text}",
                    )

            try:
                response_json = response.json()
            except Exception:
                response_text = response.text
                if status_code >= 400:
                    raise HTTPException(status_code=status_code, detail=f"HTTP error: {response_text[:200]}")
                response_json = {}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Proxy POST failed: {exc}")

    if response_json is None:
        raise HTTPException(status_code=502, detail="No response data received")

    # --- extract content & tool calls -------------------------------------
    response_text = ""
    tool_calls: list[dict[str, Any]] = []
    if isinstance(response_json, dict):
        choices = response_json.get("choices") or []
        if choices:
            message = (choices[0] or {}).get("message", {}) if choices else {}
            response_text = str(message.get("content", "") or "")
            try:
                tool_calls = parse_tool_calls_from_response(
                    response_json,
                    expected_tool_name=expected_tool_name,
                )
            except ValueError as exc:
                raise HTTPException(status_code=502, detail=str(exc)) from exc

    if validate_response:
        if not isinstance(response_json, dict) or not response_json:
            raise HTTPException(status_code=502, detail="Proxy returned missing/empty JSON")
        choices = response_json.get("choices") or []
        if not isinstance(choices, list) or len(choices) == 0:
            raise HTTPException(status_code=502, detail="Proxy JSON missing choices")
        first_msg = (choices[0] or {}).get("message", {}) if choices else {}
        if not isinstance(first_msg, dict):
            raise HTTPException(status_code=502, detail="Proxy JSON message malformed")
        content_text = str(first_msg.get("content", ""))
        if not tool_calls and not content_text.strip():
            raise HTTPException(status_code=502, detail="Empty model output: no tool_calls and no content")

    return response_text, response_json, tool_calls
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def add_health_endpoints(app: FastAPI) -> None:
    """Register standard ``/health`` and ``/health/rollout`` endpoints.

    Both endpoints return 503 when ``ENVIRONMENT_API_KEY`` is not
    configured.  Unauthorized requests still get a 200 response, augmented
    with a key-prefix hint for debugging; authorized requests receive the
    plain success body.
    """
    from synth_ai.sdk.task.auth import is_api_key_header_authorized, normalize_environment_api_key

    def _log_env_key_prefix(source: str, env_key: str | None) -> str | None:
        # Log only the first half of the key so operators can match it
        # without the full secret appearing in logs.
        if not env_key:
            return None
        prefix = env_key[: max(1, len(env_key) // 2)]
        print(f"[{source}] expected ENVIRONMENT_API_KEY prefix: {prefix}")
        return prefix

    def _auth_gate(request: Request, source: str):
        """Shared gate: returns a JSONResponse to short-circuit with, else None."""
        env_key = normalize_environment_api_key()
        if not env_key:
            return JSONResponse(
                status_code=503,
                content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"},
            )
        if not is_api_key_header_authorized(request):
            prefix = _log_env_key_prefix(source, env_key)
            content = {"status": "healthy", "authorized": False}
            if prefix:
                content["expected_api_key_prefix"] = prefix
            return JSONResponse(status_code=200, content=content)
        return None

    @app.get("/health")
    async def health(request: Request):
        early = _auth_gate(request, "health")
        if early is not None:
            return early
        return {"status": "healthy", "authorized": True}

    @app.get("/health/rollout")
    async def health_rollout(request: Request):
        early = _auth_gate(request, "health/rollout")
        if early is not None:
            return early
        return {"ok": True, "authorized": True}
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def add_metadata_endpoint(app: FastAPI) -> None:
    """Add standard /metadata endpoint.

    The endpoint returns the source text captured by
    ``get_current_module_source()`` plus a resolved module name, for
    debugging/introspection.

    NOTE(review): ``get_current_module_source()`` inspects *its* caller's
    frame, which at request time is this ``get_metadata`` handler — so the
    captured source is that of the module where this helper is defined,
    not necessarily the task-app module that called
    ``add_metadata_endpoint``.  Confirm this is the intended capture point.
    """

    @app.get("/metadata")
    async def get_metadata(request: Request):
        program_code = get_current_module_source()

        # Resolve the module name of whatever frame invoked this handler
        # (typically framework internals) for the "module_path" field.
        frame = inspect.currentframe()
        try:
            if frame is None:
                module_path = None
            else:
                caller_frame = frame.f_back
                if caller_frame is None:
                    module_path = None
                else:
                    module = inspect.getmodule(caller_frame)
                    module_path = module.__name__ if module else None
        finally:
            # Break the frame reference cycle (see the inspect module docs).
            del frame

        return {
            "program_code": program_code,
            "module_path": module_path,
            "extraction_method": "inspect",
        }
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def _extract_error_message(error_json: Any) -> str:
|
|
512
|
+
if isinstance(error_json, dict):
|
|
513
|
+
error_obj = error_json.get("error")
|
|
514
|
+
if isinstance(error_obj, dict):
|
|
515
|
+
return error_obj.get("message") or error_obj.get("detail") or str(error_obj)
|
|
516
|
+
if isinstance(error_obj, str):
|
|
517
|
+
return error_obj
|
|
518
|
+
return error_json.get("detail") or str(error_json.get("error", "Unknown error"))
|
|
519
|
+
return str(error_json)
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""Helpers for building LocalAPI rollout responses.
|
|
2
|
+
|
|
3
|
+
## Usage
|
|
4
|
+
|
|
5
|
+
response = RolloutResponseBuilder.trace_only(
|
|
6
|
+
run_id=request.run_id,
|
|
7
|
+
reward=1.0,
|
|
8
|
+
trace=trace_payload,
|
|
9
|
+
trace_correlation_id="trace_abc123",
|
|
10
|
+
inference_url="https://api.usesynth.ai/v1/trial-xyz",
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
## Key Fields
|
|
14
|
+
|
|
15
|
+
- `reward`: The outcome reward (required) → `metrics.outcome_reward`
|
|
16
|
+
- `trace_correlation_id`: Correlation ID for trace recovery (top-level)
|
|
17
|
+
- `inference_url`: Inference URL used (top-level)
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from synth_ai.sdk.task.contracts import RolloutMetrics, RolloutResponse
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class RolloutResponseBuilder:
    """Convenience builders for rollout responses."""

    @staticmethod
    def trace_only(
        *,
        run_id: str,
        reward: float,
        trace: dict[str, Any] | None,
        event_rewards: list[float] | None = None,
        trace_correlation_id: str | None = None,
        inference_url: str | None = None,
        details: dict[str, Any] | None = None,
        aborted: bool = False,
    ) -> RolloutResponse:
        """Build a RolloutResponse carrying only trace data plus metrics.

        Args:
            run_id: Request run_id to echo back.
            reward: Outcome reward for this rollout (→ metrics.outcome_reward).
            trace: v3 trace payload.
            event_rewards: Optional per-step rewards for multi-step tasks.
            trace_correlation_id: Correlation ID for trace recovery.
            inference_url: Inference URL used for this rollout.
            details: Metadata dict (debugging info, not rewards).
            aborted: Whether rollout was aborted early.
        """
        return RolloutResponse(
            run_id=run_id,
            metrics=RolloutMetrics(
                outcome_reward=float(reward),
                event_rewards=event_rewards,
                details=details or {},
            ),
            trace=_with_trace_metadata(trace, trace_correlation_id),
            trace_correlation_id=trace_correlation_id,
            inference_url=inference_url,
            aborted=aborted,
        )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _with_trace_metadata(
|
|
71
|
+
trace: dict[str, Any] | None,
|
|
72
|
+
trace_correlation_id: str | None,
|
|
73
|
+
) -> dict[str, Any] | None:
|
|
74
|
+
if trace is None:
|
|
75
|
+
return None
|
|
76
|
+
if not isinstance(trace, dict):
|
|
77
|
+
return trace
|
|
78
|
+
|
|
79
|
+
updated = dict(trace)
|
|
80
|
+
metadata = updated.get("metadata")
|
|
81
|
+
if not isinstance(metadata, dict):
|
|
82
|
+
metadata = {}
|
|
83
|
+
if trace_correlation_id:
|
|
84
|
+
metadata.setdefault("trace_correlation_id", trace_correlation_id)
|
|
85
|
+
corr_ids = metadata.get("correlation_ids")
|
|
86
|
+
if isinstance(corr_ids, dict):
|
|
87
|
+
corr_map = dict(corr_ids)
|
|
88
|
+
else:
|
|
89
|
+
corr_map = {}
|
|
90
|
+
corr_map.setdefault("trace_correlation_id", trace_correlation_id)
|
|
91
|
+
metadata["correlation_ids"] = corr_map
|
|
92
|
+
updated["metadata"] = metadata
|
|
93
|
+
return updated
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""LocalAPI server config re-exports.
|
|
2
|
+
|
|
3
|
+
Prefer this module over synth_ai.sdk.task.server.* moving forward.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from synth_ai.sdk.task.server import (
|
|
9
|
+
LocalAPIConfig,
|
|
10
|
+
ProxyConfig,
|
|
11
|
+
RubricBundle,
|
|
12
|
+
TaskAppConfig,
|
|
13
|
+
create_task_app,
|
|
14
|
+
run_task_app,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
create_local_api = create_task_app
|
|
18
|
+
run_local_api = run_task_app
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"LocalAPIConfig",
|
|
22
|
+
"TaskAppConfig",
|
|
23
|
+
"ProxyConfig",
|
|
24
|
+
"RubricBundle",
|
|
25
|
+
"create_task_app",
|
|
26
|
+
"create_local_api",
|
|
27
|
+
"run_task_app",
|
|
28
|
+
"run_local_api",
|
|
29
|
+
]
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""LocalAPI template utilities (stub).
|
|
2
|
+
|
|
3
|
+
This module provides template building utilities for LocalAPI task apps.
|
|
4
|
+
Currently a minimal stub - full implementation pending.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Any, Callable
|
|
10
|
+
|
|
11
|
+
from synth_ai.sdk.localapi import LocalAPIConfig, create_local_api
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def build_template_config(
    app_id: str = "template",
    name: str = "Template Task App",
    description: str = "A template task app.",
    **kwargs: Any,
) -> LocalAPIConfig:
    """Build a minimal LocalAPIConfig for testing/scaffolding.

    This is a placeholder - real task apps should build their own config.
    """
    # Imported lazily so the module can load without the contracts package.
    from synth_ai.sdk.task.contracts import RolloutRequest, RolloutResponse, RolloutMetrics

    async def stub_rollout(request: RolloutRequest, http_request: Any) -> RolloutResponse:
        """Stub rollout that returns empty metrics."""
        empty_trace: dict[str, Any] = {"event_history": [], "metadata": {}}
        return RolloutResponse(
            run_id=request.run_id,
            metrics=RolloutMetrics(outcome_reward=0.0),
            trace=empty_trace,
        )

    def describe_taskset() -> dict[str, Any]:
        # Single default split; enough for scaffolding and smoke tests.
        return {"id": app_id, "splits": ["default"]}

    def empty_instances(seeds: Any) -> list[Any]:
        # The template exposes no task instances.
        return []

    # base_task_info is auto-derived from app_id/name
    config = LocalAPIConfig(
        app_id=app_id,
        name=name,
        description=description,
        provide_taskset_description=describe_taskset,
        provide_task_instances=empty_instances,
        rollout=stub_rollout,
        **kwargs,
    )
    return config
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def create_template_app(**kwargs: Any):
    """Create a template FastAPI app for testing/scaffolding."""
    # Build the stub config, then hand it to the LocalAPI factory.
    return create_local_api(build_template_config(**kwargs))
|