synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
- examples/rl/configs/rl_from_base_qwen17.toml +1 -0
- examples/swe/task_app/hosted/inference/openai_client.py +0 -34
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/task_app.py +254 -36
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
- synth_ai/api/train/builders.py +90 -1
- synth_ai/api/train/cli.py +396 -21
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +15 -1
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +29 -0
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +85 -17
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +1 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/eval/core.py +13 -10
- synth_ai/cli/commands/filter/core.py +53 -17
- synth_ai/cli/commands/help/core.py +0 -1
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/train/judge_schemas.py +1 -0
- synth_ai/cli/commands/train/judge_validation.py +1 -0
- synth_ai/cli/commands/train/validation.py +0 -57
- synth_ai/cli/demo.py +35 -3
- synth_ai/cli/deploy/__init__.py +40 -25
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/task_app_deploy.py +1 -1
- synth_ai/cli/task_apps.py +53 -53
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/judge_schemas.py +1 -0
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/handlers.py +53 -4
- synth_ai/streaming/streamer.py +19 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +44 -8
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +17 -17
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +283 -1
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
- synth_ai/cli/commands/deploy/__init__.py +0 -23
- synth_ai/cli/commands/deploy/core.py +0 -614
- synth_ai/cli/commands/deploy/errors.py +0 -72
- synth_ai/cli/commands/deploy/validation.py +0 -11
- synth_ai/cli/deploy/core.py +0 -5
- synth_ai/cli/deploy/errors.py +0 -23
- synth_ai/cli/deploy/validation.py +0 -5
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import contextlib
|
|
5
5
|
import logging
|
|
6
|
+
import os
|
|
6
7
|
from typing import Any
|
|
7
8
|
|
|
8
9
|
import httpx
|
|
@@ -23,6 +24,15 @@ class OpenAIClient:
|
|
|
23
24
|
self.api_key = api_key
|
|
24
25
|
self.timeout_s = timeout_s
|
|
25
26
|
self.headers = {}
|
|
27
|
+
self._env_api_key: str | None = None
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
env_key = os.getenv("ENVIRONMENT_API_KEY") or ""
|
|
31
|
+
env_key = env_key.strip()
|
|
32
|
+
if env_key:
|
|
33
|
+
self._env_api_key = env_key
|
|
34
|
+
except Exception:
|
|
35
|
+
self._env_api_key = None
|
|
26
36
|
|
|
27
37
|
if api_key:
|
|
28
38
|
self.headers["Authorization"] = f"Bearer {api_key}"
|
|
@@ -137,18 +147,49 @@ class OpenAIClient:
|
|
|
137
147
|
Returns:
|
|
138
148
|
OpenAI-compatible chat completion response
|
|
139
149
|
"""
|
|
140
|
-
|
|
150
|
+
# Build target URL robustly: if a full endpoint is given (with query or already ending
|
|
151
|
+
# in /chat/completions), preserve it; otherwise, append the path BEFORE query params.
|
|
152
|
+
from urllib.parse import urlparse, urlunparse
|
|
153
|
+
|
|
154
|
+
candidate = (base_url or self.base_url).strip()
|
|
155
|
+
try:
|
|
156
|
+
parsed = urlparse(candidate)
|
|
157
|
+
# If no scheme, treat as relative base (pass-through)
|
|
158
|
+
if not parsed.scheme or not parsed.netloc:
|
|
159
|
+
base_no_slash = candidate.rstrip("/")
|
|
160
|
+
url = f"{base_no_slash}/v1/chat/completions"
|
|
161
|
+
else:
|
|
162
|
+
path = (parsed.path or "").rstrip("/")
|
|
163
|
+
if path.endswith("/v1/chat/completions") or path.endswith("/chat/completions"):
|
|
164
|
+
new_path = path
|
|
165
|
+
elif path.endswith("/v1"):
|
|
166
|
+
new_path = f"{path}/chat/completions"
|
|
167
|
+
elif path.endswith("/chat"):
|
|
168
|
+
new_path = f"{path}/completions"
|
|
169
|
+
else:
|
|
170
|
+
new_path = f"{path}/v1/chat/completions" if path else "/v1/chat/completions"
|
|
171
|
+
url = urlunparse(parsed._replace(path=new_path))
|
|
172
|
+
except Exception:
|
|
173
|
+
# Fallback to legacy behavior
|
|
174
|
+
url = (base_url or self.base_url).rstrip("/") + "/v1/chat/completions"
|
|
141
175
|
timeout = timeout_s or self.timeout_s
|
|
142
176
|
|
|
143
177
|
# Merge headers
|
|
144
178
|
headers = self.headers.copy()
|
|
179
|
+
try:
|
|
180
|
+
parsed_target = urlparse(url)
|
|
181
|
+
path_for_auth = (parsed_target.path or "") if parsed_target else ""
|
|
182
|
+
if self._env_api_key and "/proxy/" in path_for_auth:
|
|
183
|
+
headers.setdefault("X-API-Key", self._env_api_key)
|
|
184
|
+
except Exception:
|
|
185
|
+
pass
|
|
145
186
|
if extra_headers:
|
|
146
187
|
headers.update(extra_headers)
|
|
147
188
|
|
|
148
189
|
# Fix parameter compatibility for newer models
|
|
149
190
|
processed_request = self._fix_model_parameters(request, target_url=url)
|
|
150
191
|
|
|
151
|
-
# Log request (
|
|
192
|
+
# Log request with detailed prompts/tools preview and sampling settings (Authorization is not logged)
|
|
152
193
|
logger.info(f"Inference POST target: {url}")
|
|
153
194
|
if extra_headers:
|
|
154
195
|
logger.info(f"Extra headers: {extra_headers}")
|
|
@@ -156,6 +197,62 @@ class OpenAIClient:
|
|
|
156
197
|
keys_preview = sorted(processed_request.keys())
|
|
157
198
|
logger.info(f"Request keys: {keys_preview}")
|
|
158
199
|
|
|
200
|
+
# Detailed IO log: messages/tools/sampling and final payload fields
|
|
201
|
+
try:
|
|
202
|
+
import json as _json
|
|
203
|
+
|
|
204
|
+
def _truncate(text: str, limit: int = 2000) -> str:
|
|
205
|
+
return text if len(text) <= limit else text[:limit] + "…"
|
|
206
|
+
|
|
207
|
+
def _messages_preview(msgs: Any) -> str:
|
|
208
|
+
try:
|
|
209
|
+
out: list[dict[str, Any]] = []
|
|
210
|
+
if isinstance(msgs, list):
|
|
211
|
+
for m in msgs:
|
|
212
|
+
if not isinstance(m, dict):
|
|
213
|
+
continue
|
|
214
|
+
role = m.get("role")
|
|
215
|
+
content = m.get("content")
|
|
216
|
+
if isinstance(content, str):
|
|
217
|
+
text = content
|
|
218
|
+
elif isinstance(content, list):
|
|
219
|
+
parts: list[str] = []
|
|
220
|
+
for seg in content:
|
|
221
|
+
if isinstance(seg, dict) and isinstance(seg.get("text"), str):
|
|
222
|
+
parts.append(seg["text"])
|
|
223
|
+
text = "\n".join(parts)
|
|
224
|
+
else:
|
|
225
|
+
text = ""
|
|
226
|
+
out.append({"role": role, "content": _truncate(str(text), 4000)})
|
|
227
|
+
return _json.dumps(out)
|
|
228
|
+
except Exception:
|
|
229
|
+
return "[]"
|
|
230
|
+
|
|
231
|
+
def _tools_preview(tools: Any) -> str:
|
|
232
|
+
try:
|
|
233
|
+
return _truncate(_json.dumps(tools), 4000)
|
|
234
|
+
except Exception:
|
|
235
|
+
return "[]"
|
|
236
|
+
|
|
237
|
+
msgs = processed_request.get("messages") if isinstance(processed_request, dict) else None
|
|
238
|
+
tools = processed_request.get("tools") if isinstance(processed_request, dict) else None
|
|
239
|
+
io_log: dict[str, Any] = {
|
|
240
|
+
"llm.call": True,
|
|
241
|
+
"model": processed_request.get("model") if isinstance(processed_request, dict) else None,
|
|
242
|
+
"tool_choice": processed_request.get("tool_choice") if isinstance(processed_request, dict) else None,
|
|
243
|
+
"parallel_tool_calls": processed_request.get("parallel_tool_calls") if isinstance(processed_request, dict) else None,
|
|
244
|
+
"stop_after_tool_calls": processed_request.get("stop_after_tool_calls") if isinstance(processed_request, dict) else None,
|
|
245
|
+
"temperature": processed_request.get("temperature") if isinstance(processed_request, dict) else None,
|
|
246
|
+
"top_p": processed_request.get("top_p") if isinstance(processed_request, dict) else None,
|
|
247
|
+
"max_tokens": processed_request.get("max_tokens") if isinstance(processed_request, dict) else None,
|
|
248
|
+
"max_completion_tokens": processed_request.get("max_completion_tokens") if isinstance(processed_request, dict) else None,
|
|
249
|
+
"messages_preview": _messages_preview(msgs),
|
|
250
|
+
"tools_preview": _tools_preview(tools),
|
|
251
|
+
}
|
|
252
|
+
logger.info(io_log)
|
|
253
|
+
except Exception:
|
|
254
|
+
pass
|
|
255
|
+
|
|
159
256
|
# Final hard-guard for OpenAI/Groq: ensure unsupported field is not present
|
|
160
257
|
try:
|
|
161
258
|
low_url = url.lower()
|
|
@@ -228,13 +325,54 @@ class OpenAIClient:
|
|
|
228
325
|
f"Inference response status=200, content-type={content_type}, bytes={len(body_text)}"
|
|
229
326
|
)
|
|
230
327
|
if body_text:
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
328
|
+
# Log raw output with generous preview to debug no-tool-call issues
|
|
329
|
+
preview_len = min(4000, len(body_text))
|
|
330
|
+
logger.info({
|
|
331
|
+
"llm.raw_response": True,
|
|
332
|
+
"bytes": len(body_text),
|
|
333
|
+
"preview": body_text[:preview_len],
|
|
334
|
+
})
|
|
235
335
|
|
|
236
336
|
result = response.json()
|
|
237
337
|
logger.info(f"Inference response parsed_type={type(result).__name__}")
|
|
338
|
+
|
|
339
|
+
# Normalize tool calls so downstream always sees a function tool call
|
|
340
|
+
try:
|
|
341
|
+
if isinstance(result, dict):
|
|
342
|
+
choices = result.get("choices")
|
|
343
|
+
if isinstance(choices, list) and choices:
|
|
344
|
+
msg = choices[0].get("message")
|
|
345
|
+
if isinstance(msg, dict):
|
|
346
|
+
# Prefer tool_calls; if missing but function_call is present, synthesize tool_calls
|
|
347
|
+
tc = msg.get("tool_calls")
|
|
348
|
+
fc = msg.get("function_call")
|
|
349
|
+
if (not isinstance(tc, list) or not tc) and isinstance(fc, dict):
|
|
350
|
+
name = fc.get("name") or "interact_many"
|
|
351
|
+
args = fc.get("arguments") or "{}"
|
|
352
|
+
msg["tool_calls"] = [
|
|
353
|
+
{
|
|
354
|
+
"id": "call_norm",
|
|
355
|
+
"type": "function",
|
|
356
|
+
"function": {"name": name, "arguments": args},
|
|
357
|
+
}
|
|
358
|
+
]
|
|
359
|
+
# Encourage downstream to treat this as a tool call
|
|
360
|
+
if isinstance(choices[0], dict):
|
|
361
|
+
choices[0]["finish_reason"] = "tool_calls"
|
|
362
|
+
# Log tool call count for debugging
|
|
363
|
+
try:
|
|
364
|
+
tc2 = msg.get("tool_calls")
|
|
365
|
+
count = len(tc2) if isinstance(tc2, list) else 0
|
|
366
|
+
logger.info({
|
|
367
|
+
"llm.tool_calls": True,
|
|
368
|
+
"count": count,
|
|
369
|
+
"finish_reason": choices[0].get("finish_reason") if isinstance(choices[0], dict) else None,
|
|
370
|
+
})
|
|
371
|
+
except Exception:
|
|
372
|
+
pass
|
|
373
|
+
except Exception:
|
|
374
|
+
pass
|
|
375
|
+
|
|
238
376
|
return result
|
|
239
377
|
|
|
240
378
|
except httpx.TimeoutException:
|
|
@@ -340,40 +478,6 @@ class OpenAIClient:
|
|
|
340
478
|
pass
|
|
341
479
|
except Exception:
|
|
342
480
|
pass
|
|
343
|
-
# Gracefully degrade on 422 so rollouts can still produce a trajectory
|
|
344
|
-
if status == 422:
|
|
345
|
-
try:
|
|
346
|
-
# Best-effort parse of error for diagnostics
|
|
347
|
-
err = None
|
|
348
|
-
try:
|
|
349
|
-
err = e.response.json()
|
|
350
|
-
except Exception:
|
|
351
|
-
err = {"error": "unprocessable", "detail": (text or "")[:200]}
|
|
352
|
-
logger.warning(
|
|
353
|
-
{
|
|
354
|
-
"inference_422_recovered": True,
|
|
355
|
-
"detail": err,
|
|
356
|
-
}
|
|
357
|
-
)
|
|
358
|
-
except Exception:
|
|
359
|
-
pass
|
|
360
|
-
# Return a minimal OpenAI-compatible response with no tool_calls/content
|
|
361
|
-
import time as _t
|
|
362
|
-
|
|
363
|
-
return {
|
|
364
|
-
"id": f"cmpl-{int(_t.time())}",
|
|
365
|
-
"object": "chat.completion",
|
|
366
|
-
"created": int(_t.time()),
|
|
367
|
-
"model": processed_request.get("model") or "unknown",
|
|
368
|
-
"choices": [
|
|
369
|
-
{
|
|
370
|
-
"index": 0,
|
|
371
|
-
"message": {"role": "assistant", "content": "", "tool_calls": []},
|
|
372
|
-
"finish_reason": "stop",
|
|
373
|
-
}
|
|
374
|
-
],
|
|
375
|
-
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
|
376
|
-
}
|
|
377
481
|
raise
|
|
378
482
|
except Exception as e:
|
|
379
483
|
logger.error(f"Unexpected error calling {url}: {e}")
|
|
@@ -399,7 +503,14 @@ class OpenAIClient:
|
|
|
399
503
|
|
|
400
504
|
try:
|
|
401
505
|
async with httpx.AsyncClient(timeout=timeout) as client:
|
|
402
|
-
|
|
506
|
+
headers = self.headers.copy()
|
|
507
|
+
try:
|
|
508
|
+
parsed = httpx.URL(url)
|
|
509
|
+
if self._env_api_key and "/proxy/" in (parsed.path or ""):
|
|
510
|
+
headers.setdefault("X-API-Key", self._env_api_key)
|
|
511
|
+
except Exception:
|
|
512
|
+
pass
|
|
513
|
+
response = await client.get(url, headers=headers)
|
|
403
514
|
response.raise_for_status()
|
|
404
515
|
return response.json()
|
|
405
516
|
except httpx.HTTPStatusError as e:
|
|
@@ -466,11 +466,20 @@ async def step_policy(
|
|
|
466
466
|
|
|
467
467
|
if tracing_context is not None:
|
|
468
468
|
try:
|
|
469
|
+
print(
|
|
470
|
+
f"[TRACE_DEBUG] record_policy_prompts sys={len(system_prompt_records)} user={len(user_prompt_records)}",
|
|
471
|
+
flush=True,
|
|
472
|
+
)
|
|
469
473
|
await tracing_context.record_policy_prompts(
|
|
470
474
|
system_prompt_records, user_prompt_records
|
|
471
475
|
)
|
|
472
476
|
except Exception as exc:
|
|
473
477
|
logger.debug(f"TRACING_PROMPTS_FAIL: {exc}")
|
|
478
|
+
else:
|
|
479
|
+
print(
|
|
480
|
+
f"[TRACE_DEBUG] Missing tracing context on policy step; policy_id={request.policy_id}",
|
|
481
|
+
flush=True,
|
|
482
|
+
)
|
|
474
483
|
|
|
475
484
|
# Create inference client (choose API key by target provider)
|
|
476
485
|
# Require inference_url to be set explicitly by the rollout policy config.
|
|
@@ -492,7 +501,11 @@ async def step_policy(
|
|
|
492
501
|
if isinstance(target_url, str):
|
|
493
502
|
low_url = target_url.lower()
|
|
494
503
|
# Proxy endpoints should not receive a bearer; the server-side proxy holds the vendor key
|
|
495
|
-
if
|
|
504
|
+
if (
|
|
505
|
+
"/proxy/groq" in low_url
|
|
506
|
+
or "/proxy/openai" in low_url
|
|
507
|
+
or "/proxy/v1" in low_url
|
|
508
|
+
):
|
|
496
509
|
api_key_override = None
|
|
497
510
|
elif "openai.com" in low_url:
|
|
498
511
|
api_key_override = _os.getenv("OPENAI_API_KEY") or getattr(
|
|
@@ -954,6 +967,23 @@ async def step_policy(
|
|
|
954
967
|
except Exception as exc:
|
|
955
968
|
logger.debug(f"TRACING_LLM_FAIL: {exc}")
|
|
956
969
|
|
|
970
|
+
if not tool_calls:
|
|
971
|
+
preview = ""
|
|
972
|
+
try:
|
|
973
|
+
preview = str(meta.get("raw_response") or "")[:400]
|
|
974
|
+
except Exception:
|
|
975
|
+
preview = "<unavailable>"
|
|
976
|
+
logger.error(
|
|
977
|
+
{
|
|
978
|
+
"rollout.policy_step": True,
|
|
979
|
+
"policy_id": request.policy_id,
|
|
980
|
+
"error": "no_tool_calls",
|
|
981
|
+
"inference_url": meta.get("inference_url"),
|
|
982
|
+
"raw_preview": preview,
|
|
983
|
+
}
|
|
984
|
+
)
|
|
985
|
+
raise RuntimeError("Policy step produced no tool calls; inference response unusable.")
|
|
986
|
+
|
|
957
987
|
return PolicyStepResponse(
|
|
958
988
|
tool_calls=tool_calls,
|
|
959
989
|
meta=meta,
|
|
@@ -223,6 +223,7 @@ class RolloutTracingContext:
|
|
|
223
223
|
).lower()
|
|
224
224
|
self.return_trace = bool(getattr(request.record, "return_trace", False))
|
|
225
225
|
self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
|
|
226
|
+
print(f"[TRACE_DEBUG] RolloutTracingContext init: trace_format={self.trace_format} return_trace={self.return_trace}", flush=True)
|
|
226
227
|
self.session_trace = None
|
|
227
228
|
self.metadata_updates: dict[str, Any] = {}
|
|
228
229
|
self.policy_name = request.policy.policy_name or ""
|
|
@@ -244,19 +245,24 @@ class RolloutTracingContext:
|
|
|
244
245
|
|
|
245
246
|
async def start_session(self) -> None:
|
|
246
247
|
if not self.enabled or self.tracer is None:
|
|
248
|
+
print("[TRACE_DEBUG] start_session skipped: tracer disabled", flush=True)
|
|
247
249
|
return
|
|
248
250
|
try:
|
|
249
251
|
await self.tracer.initialize()
|
|
252
|
+
print("[TRACE_DEBUG] tracer initialized", flush=True)
|
|
250
253
|
except Exception as exc:
|
|
251
254
|
logger.debug("TRACING_INIT_FAIL: %s", exc)
|
|
255
|
+
# Hard fail: tracing requested but cannot initialize
|
|
256
|
+
raise
|
|
252
257
|
try:
|
|
253
258
|
await self.tracer.start_session(
|
|
254
259
|
session_id=self.run_id, metadata=dict(self.metadata_base)
|
|
255
260
|
)
|
|
261
|
+
print(f"[TRACE_DEBUG] start_session succeeded for run_id={self.run_id}", flush=True)
|
|
256
262
|
except Exception as exc:
|
|
257
263
|
logger.warning("TRACING_START_FAIL: %s", exc)
|
|
258
|
-
|
|
259
|
-
|
|
264
|
+
# Hard fail: tracing requested but cannot start session
|
|
265
|
+
raise
|
|
260
266
|
|
|
261
267
|
async def start_decision(self, turn_number: int) -> None:
|
|
262
268
|
self.current_turn = turn_number
|
|
@@ -317,6 +323,9 @@ class RolloutTracingContext:
|
|
|
317
323
|
)
|
|
318
324
|
except Exception as exc:
|
|
319
325
|
logger.debug("TRACING_USER_MSG_FAIL: %s", exc)
|
|
326
|
+
if self.tracer and self.tracer._current_trace:
|
|
327
|
+
msg_count = len(self.tracer._current_trace.markov_blanket_message_history)
|
|
328
|
+
print(f"[TRACE_DEBUG] After record_policy_prompts: {msg_count} messages", flush=True)
|
|
320
329
|
|
|
321
330
|
def _content_to_text(self, content: Any) -> str:
|
|
322
331
|
if isinstance(content, str):
|
|
@@ -395,6 +404,11 @@ class RolloutTracingContext:
|
|
|
395
404
|
message_type="policy_tool_call",
|
|
396
405
|
metadata=self._message_metadata(),
|
|
397
406
|
)
|
|
407
|
+
if self.tracer._current_trace:
|
|
408
|
+
print(
|
|
409
|
+
f"[TRACE_DEBUG] After tool invocation: messages={len(self.tracer._current_trace.markov_blanket_message_history)}",
|
|
410
|
+
flush=True,
|
|
411
|
+
)
|
|
398
412
|
except Exception as exc:
|
|
399
413
|
logger.debug("TRACING_TOOL_MSG_FAIL: %s", exc)
|
|
400
414
|
|
|
@@ -664,12 +678,24 @@ class RolloutTracingContext:
|
|
|
664
678
|
except Exception as exc:
|
|
665
679
|
logger.debug("TRACING_OUTCOME_FAIL: %s", exc)
|
|
666
680
|
try:
|
|
681
|
+
if self.tracer._current_trace:
|
|
682
|
+
msg_count = len(self.tracer._current_trace.markov_blanket_message_history)
|
|
683
|
+
print(f"[TRACE_DEBUG] Before end_session: {msg_count} messages in trace", flush=True)
|
|
667
684
|
self.session_trace = await self.tracer.end_session()
|
|
668
685
|
if self.session_trace is not None:
|
|
669
686
|
self.session_trace.metadata.update(self.metadata_updates)
|
|
687
|
+
print(
|
|
688
|
+
f"[TRACE_DEBUG] Session ended successfully, session_id={self.session_trace.session_id}",
|
|
689
|
+
flush=True,
|
|
690
|
+
)
|
|
691
|
+
print(
|
|
692
|
+
f"[TRACE_DEBUG] session_trace.metadata keys: {list(self.session_trace.metadata.keys())}",
|
|
693
|
+
flush=True,
|
|
694
|
+
)
|
|
670
695
|
except Exception as exc:
|
|
671
696
|
logger.debug("TRACING_END_SESSION_FAIL: %s", exc)
|
|
672
697
|
self.session_trace = None
|
|
698
|
+
print(f"[TRACE_DEBUG] end_session failed for run_id={self.run_id}: {exc}", flush=True)
|
|
673
699
|
with contextlib.suppress(Exception):
|
|
674
700
|
await self.tracer.close()
|
|
675
701
|
|
|
@@ -700,9 +726,13 @@ class RolloutTracingContext:
|
|
|
700
726
|
def build_trace_payload(self, session_trace: Any) -> dict[str, Any] | None:
|
|
701
727
|
if not self.return_trace or session_trace is None:
|
|
702
728
|
return None
|
|
703
|
-
if self.trace_format
|
|
729
|
+
if self.trace_format in ("full", "structured"):
|
|
704
730
|
payload = session_trace.to_dict()
|
|
705
731
|
payload.setdefault("metadata", {}).update(self.metadata_updates)
|
|
732
|
+
print(
|
|
733
|
+
f"[TRACE_DEBUG] build_trace_payload returning structured trace with messages={len(payload.get('markov_blanket_message_history') or [])}",
|
|
734
|
+
flush=True,
|
|
735
|
+
)
|
|
706
736
|
return payload
|
|
707
737
|
metadata = dict(session_trace.metadata)
|
|
708
738
|
metadata.update(self.metadata_updates)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Utility functions for the task service."""
|
|
2
2
|
|
|
3
3
|
from typing import Any
|
|
4
|
+
from urllib.parse import urlparse, urlunparse
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
|
|
@@ -60,3 +61,69 @@ def sanitize_observation(observation: dict[str, Any]) -> dict[str, Any]:
|
|
|
60
61
|
sanitized[key] = convert_numpy_to_python(value)
|
|
61
62
|
|
|
62
63
|
return sanitized
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
_CHAT_COMPLETIONS_SUFFIX = "/v1/chat/completions"

# Path endings that are a leading part of ``_CHAT_COMPLETIONS_SUFFIX``. A path
# ending in one of these is completed in place instead of having the whole
# suffix appended (which would produce ``.../v1/v1/chat/completions``).
_PARTIAL_CHAT_COMPLETIONS_SUFFIXES = ("/v1/chat", "/v1")


def _complete_chat_completions_path(path: str) -> str:
    """Return ``path`` (trailing slashes removed) ending in ``/v1/chat/completions``."""
    base = path.rstrip("/")
    if base.endswith(_CHAT_COMPLETIONS_SUFFIX):
        return base
    for partial in _PARTIAL_CHAT_COMPLETIONS_SUFFIXES:
        if base.endswith(partial):
            return base[: -len(partial)] + _CHAT_COMPLETIONS_SUFFIX
    return base + _CHAT_COMPLETIONS_SUFFIX


def force_normalize_chat_completions_url(raw_url: Any) -> Any:
    """
    Convert a possibly malformed inference URL into chat-completions form.

    Ensures the path ends with ``/v1/chat/completions`` and that the query
    string contains no ``/`` segments. Three kinds of damage are repaired:

    * a path appended after the query string
      (``https://host/api?cid=x/v1/chat/completions`` becomes
      ``https://host/api/v1/chat/completions?cid=x``); the URL's real path is
      kept rather than discarded;
    * a missing or partial suffix (``https://host/v1`` becomes
      ``https://host/v1/chat/completions``, not ``/v1/v1/chat/completions``);
    * any ``/`` still left in the query, which truncates the query at the
      first ``/`` as a last resort.

    Args:
        raw_url: The URL to normalize. Non-string and blank values are returned
            unchanged.

    Returns:
        The normalized URL string (whitespace-stripped), or ``raw_url`` itself
        when it is not a non-blank string. A URL that needs no repair is
        returned stripped but otherwise untouched.
    """
    if not isinstance(raw_url, str):
        return raw_url
    url = raw_url.strip()
    if not url:
        return raw_url

    parsed = urlparse(url)
    path = parsed.path or ""
    query = parsed.query or ""
    changed = False

    # A path was glued onto the end of the query string; move it back into the path.
    if "/" in query:
        before_slash, after_slash = query.split("/", 1)
        cut_positions = [i for i in (after_slash.find("&"), after_slash.find("?")) if i >= 0]
        cut = min(cut_positions) if cut_positions else len(after_slash)
        embedded_path = "/" + after_slash[:cut].rstrip("/")
        extra_query = after_slash[cut + 1 :]
        query = "&".join(part for part in (before_slash, extra_query) if part)
        if embedded_path == "/v1" or embedded_path.startswith("/v1/"):
            # The fragment is (part of) the chat suffix: complete the real path
            # instead of replacing it.
            path = _complete_chat_completions_path(path)
        else:
            path = _complete_chat_completions_path(path.rstrip("/") + embedded_path)
        changed = True

    # Make sure the path carries the chat-completions suffix.
    if not path.rstrip("/").endswith(_CHAT_COMPLETIONS_SUFFIX):
        path = _complete_chat_completions_path(path)
        changed = True

    # Last resort: the query must not contain '/' segments.
    if "/" in query:
        query = query.split("/", 1)[0]
        changed = True

    if changed:
        url = urlunparse(parsed._replace(path=path, query=query))
    return url
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def ensure_chat_completions_url(raw_url: Any, mode: Any = None) -> Any:
    """
    Return ``raw_url`` normalized to a chat-completions endpoint.

    ``mode`` (e.g. RL vs. EVAL) is accepted so callers can pass it, but it is
    currently ignored: every mode delegates to
    ``force_normalize_chat_completions_url``.
    """
    normalized = force_normalize_chat_completions_url(raw_url)
    return normalized
|
synth_ai/api/train/builders.py
CHANGED
|
@@ -33,7 +33,7 @@ try:
|
|
|
33
33
|
except Exception as exc: # pragma: no cover - critical dependency
|
|
34
34
|
raise RuntimeError("Unable to load SFT payload helpers") from exc
|
|
35
35
|
|
|
36
|
-
from .configs import RLConfig, SFTConfig
|
|
36
|
+
from .configs import PromptLearningConfig, RLConfig, SFTConfig
|
|
37
37
|
from .supported_algos import (
|
|
38
38
|
AlgorithmValidationError,
|
|
39
39
|
ensure_model_supported_for_algorithm,
|
|
@@ -56,6 +56,12 @@ class SFTBuildResult:
|
|
|
56
56
|
validation_file: Path | None
|
|
57
57
|
|
|
58
58
|
|
|
59
|
+
@dataclass(slots=True)
|
|
60
|
+
class PromptLearningBuildResult:
|
|
61
|
+
payload: dict[str, Any]
|
|
62
|
+
task_url: str
|
|
63
|
+
|
|
64
|
+
|
|
59
65
|
def _format_validation_error(path: Path, exc: ValidationError) -> str:
|
|
60
66
|
lines: list[str] = []
|
|
61
67
|
for error in exc.errors():
|
|
@@ -86,6 +92,11 @@ def build_rl_payload(
|
|
|
86
92
|
raise click.ClickException(_format_validation_error(config_path, exc)) from exc
|
|
87
93
|
|
|
88
94
|
data = rl_cfg.to_dict()
|
|
95
|
+
|
|
96
|
+
# Remove smoke section - it's CLI-only and should not be sent to the trainer
|
|
97
|
+
if "smoke" in data:
|
|
98
|
+
del data["smoke"]
|
|
99
|
+
|
|
89
100
|
# Ensure required [reference] section for backend validators
|
|
90
101
|
try:
|
|
91
102
|
ref_cfg = data.get("reference") if isinstance(data, dict) else None
|
|
@@ -349,9 +360,87 @@ def build_sft_payload(
|
|
|
349
360
|
return SFTBuildResult(payload=payload, train_file=dataset_path, validation_file=validation_file)
|
|
350
361
|
|
|
351
362
|
|
|
363
|
+
def build_prompt_learning_payload(
    *,
    config_path: Path,
    task_url: str | None,
    overrides: dict[str, Any],
    allow_experimental: bool | None = None,
) -> PromptLearningBuildResult:
    """Build payload for prompt learning job (MIPRO or GEPA).

    The raw TOML at ``config_path`` is first checked by
    ``validate_prompt_learning_config`` (SDK-side, before anything is sent to
    the backend), then parsed into ``PromptLearningConfig``. Pydantic
    validation errors from that parse are reported as ``click.ClickException``.

    Args:
        config_path: Path to the prompt learning TOML config.
        task_url: Not consulted by this builder; the task app URL comes only
            from ``prompt_learning.task_app_url`` in the TOML.
        overrides: Mapping whose optional keys ``overrides``, ``metadata``,
            ``auto_start`` and ``backend`` are folded into the payload. Neither
            this mapping nor its ``metadata`` dict is mutated.
        allow_experimental: Not used by this builder.

    Returns:
        A ``PromptLearningBuildResult`` holding the payload and the task URL.

    Raises:
        click.ClickException: If the config fails pydantic validation, or the
            task app URL or task app API key cannot be resolved. Errors raised
            by ``validate_prompt_learning_config`` propagate unchanged.
    """
    import os

    from pydantic import ValidationError

    from .configs.prompt_learning import load_toml

    # SDK-SIDE VALIDATION: Catch errors BEFORE sending to backend
    from .validators import validate_prompt_learning_config

    raw_config = load_toml(config_path)
    validate_prompt_learning_config(raw_config, config_path)

    try:
        pl_cfg = PromptLearningConfig.from_path(config_path)
    except ValidationError as exc:
        raise click.ClickException(_format_validation_error(config_path, exc)) from exc

    # Source of truth: TOML only (the ``task_url`` argument is not consulted).
    final_task_url = (pl_cfg.task_app_url or "").strip()
    if not final_task_url:
        # NOTE(review): the message mentions --task-url although ``task_url`` is
        # ignored here — confirm whether the CLI resolves it before calling.
        raise click.ClickException(
            "Task app URL required (provide --task-url or set prompt_learning.task_app_url in TOML)"
        )

    # The TOML value wins; ENVIRONMENT_API_KEY is the fallback.
    task_app_api_key = (
        pl_cfg.task_app_api_key
        or os.environ.get("ENVIRONMENT_API_KEY", "")
    ).strip()
    if not task_app_api_key:
        raise click.ClickException(
            "Task app API key required (set prompt_learning.task_app_api_key in TOML or ENVIRONMENT_API_KEY env var)"
        )

    # Build config dict for backend
    config_dict = pl_cfg.to_dict()

    # Ensure task_app_url and task_app_api_key are set. A missing or non-dict
    # section is replaced by a fresh dict stored back into config_dict, so the
    # values are never written into a throwaway default.
    pl_section = config_dict.get("prompt_learning")
    if not isinstance(pl_section, dict):
        pl_section = {}
        config_dict["prompt_learning"] = pl_section
    pl_section["task_app_url"] = final_task_url
    pl_section["task_app_api_key"] = task_app_api_key

    # Copy metadata so adding backend_base_url below does not mutate the
    # caller's ``overrides`` (and a None value does not break item assignment).
    metadata: dict[str, Any] = dict(overrides.get("metadata") or {})

    # Build payload matching backend API format
    payload: dict[str, Any] = {
        "algorithm": pl_cfg.algorithm,
        "config_body": config_dict,
        "overrides": overrides.get("overrides", {}),
        "metadata": metadata,
        "auto_start": overrides.get("auto_start", True),
    }

    backend = overrides.get("backend")
    if backend:
        metadata["backend_base_url"] = ensure_api_base(str(backend))

    return PromptLearningBuildResult(payload=payload, task_url=final_task_url)
|
|
437
|
+
|
|
438
|
+
|
|
352
439
|
# Public API of the builders module.
__all__ = [
    # Build-result containers
    "PromptLearningBuildResult",
    "RLBuildResult",
    "SFTBuildResult",
    # Payload builders
    "build_prompt_learning_payload",
    "build_rl_payload",
    "build_sft_payload",
]
|