synth-ai 0.2.13.dev2__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +5 -4
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
- examples/sft/evaluate.py +2 -0
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +1 -0
- examples/swe/task_app/hosted/rollout.py +2 -0
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +3 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +306 -8
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +16 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +52 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +111 -13
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +156 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +2 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +2 -0
- examples/task_apps/pokemon_red/task_app.py +199 -6
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +2 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +8 -4
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +258 -23
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +2 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- synth_ai/api/models/supported.py +1 -0
- synth_ai/cli/__init__.py +46 -13
- synth_ai/cli/_modal_wrapper.py +3 -2
- synth_ai/cli/recent.py +1 -1
- synth_ai/cli/status.py +1 -1
- synth_ai/cli/task_apps.py +354 -143
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/tui.py +57 -0
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/verilog/engine.py +76 -10
- synth_ai/judge_schemas.py +8 -8
- synth_ai/task/__init__.py +11 -1
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +257 -0
- synth_ai/task/contracts.py +15 -2
- synth_ai/task/rubrics/__init__.py +3 -0
- synth_ai/task/rubrics/loaders.py +22 -3
- synth_ai/task/rubrics/scoring.py +3 -0
- synth_ai/task/trace_correlation_helpers.py +315 -0
- synth_ai/task/validators.py +144 -0
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/session_tracer.py +16 -6
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +8 -7
- synth_ai/tracing_v3/turso/native_manager.py +63 -40
- synth_ai/tracing_v3/utils.py +3 -3
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +906 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/METADATA +1 -1
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/RECORD +110 -71
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
|
@@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
|
|
|
13
13
|
from synth_ai.lm.vendors.base import BaseLMResponse
|
|
14
14
|
from synth_ai.task.tracing_utils import unique_sft_path
|
|
15
15
|
from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
|
|
16
|
+
from synth_ai.task.contracts import RolloutMode
|
|
16
17
|
from synth_ai.tracing_v3.llm_call_record_helpers import create_llm_call_record_from_response
|
|
17
18
|
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
18
19
|
|
|
@@ -120,6 +121,8 @@ class RolloutRequest(BaseModel):
|
|
|
120
121
|
# Optional run/session context
|
|
121
122
|
training_session_id: str | None = None
|
|
122
123
|
synth_base_url: str | None = None
|
|
124
|
+
# Mode controls URL transformation: REQUIRED to make intent explicit
|
|
125
|
+
mode: RolloutMode
|
|
123
126
|
|
|
124
127
|
|
|
125
128
|
class RolloutStep(BaseModel):
|
|
@@ -140,6 +143,7 @@ class RolloutTrajectory(BaseModel):
|
|
|
140
143
|
final: dict[str, Any] | None = None
|
|
141
144
|
length: int
|
|
142
145
|
decision_samples: list[dict[str, Any]] | None = None
|
|
146
|
+
inference_url: str | None = None
|
|
143
147
|
|
|
144
148
|
|
|
145
149
|
def _normalize_step_strategy(raw_strategy: Any) -> str:
|
|
@@ -452,11 +456,12 @@ class RolloutMetrics(BaseModel):
|
|
|
452
456
|
class RolloutResponse(BaseModel):
|
|
453
457
|
run_id: str
|
|
454
458
|
trajectories: list[RolloutTrajectory]
|
|
455
|
-
branches: dict[str, list[str]] =
|
|
459
|
+
branches: dict[str, list[str]] = Field(default_factory=dict)
|
|
456
460
|
metrics: RolloutMetrics
|
|
457
461
|
aborted: bool = False
|
|
458
462
|
ops_executed: int = 0
|
|
459
463
|
trace: dict[str, Any] | None = None
|
|
464
|
+
pipeline_metadata: dict[str, Any] = Field(default_factory=dict)
|
|
460
465
|
|
|
461
466
|
|
|
462
467
|
class RolloutTracingContext:
|
|
@@ -567,7 +572,7 @@ class RolloutTracingContext:
|
|
|
567
572
|
try:
|
|
568
573
|
await self.tracer.record_message(
|
|
569
574
|
content=self._prompt_payload(entry, role="system"),
|
|
570
|
-
message_type="
|
|
575
|
+
message_type="system", # Use standard message type
|
|
571
576
|
metadata=self._message_metadata(),
|
|
572
577
|
)
|
|
573
578
|
except Exception as exc:
|
|
@@ -576,11 +581,16 @@ class RolloutTracingContext:
|
|
|
576
581
|
try:
|
|
577
582
|
await self.tracer.record_message(
|
|
578
583
|
content=self._prompt_payload(entry, role="user"),
|
|
579
|
-
message_type="
|
|
584
|
+
message_type="user", # Use standard message type
|
|
580
585
|
metadata=self._message_metadata(),
|
|
581
586
|
)
|
|
582
587
|
except Exception as exc:
|
|
583
588
|
logger.debug("TRACING_USER_MSG_FAIL: %s", exc)
|
|
589
|
+
|
|
590
|
+
# Debug: Check message count
|
|
591
|
+
if self.tracer and self.tracer._current_trace:
|
|
592
|
+
msg_count = len(self.tracer._current_trace.markov_blanket_message_history)
|
|
593
|
+
logger.info(f"[TRACE_DEBUG] After record_policy_prompts: {msg_count} messages in trace")
|
|
584
594
|
|
|
585
595
|
def _content_to_text(self, content: Any) -> str:
|
|
586
596
|
if isinstance(content, str):
|
|
@@ -656,8 +666,8 @@ class RolloutTracingContext:
|
|
|
656
666
|
try:
|
|
657
667
|
await self.tracer.record_message(
|
|
658
668
|
content=self._safe_json(tool_calls),
|
|
659
|
-
message_type="
|
|
660
|
-
metadata=self._message_metadata(),
|
|
669
|
+
message_type="assistant", # Map to standard assistant message type
|
|
670
|
+
metadata={**self._message_metadata(), "is_tool_call": True},
|
|
661
671
|
)
|
|
662
672
|
except Exception as exc:
|
|
663
673
|
logger.debug("TRACING_TOOL_MSG_FAIL: %s", exc)
|
|
@@ -928,11 +938,22 @@ class RolloutTracingContext:
|
|
|
928
938
|
except Exception as exc:
|
|
929
939
|
logger.debug("TRACING_OUTCOME_FAIL: %s", exc)
|
|
930
940
|
try:
|
|
941
|
+
# Debug: Check message count before end_session
|
|
942
|
+
if self.tracer._current_trace:
|
|
943
|
+
msg_count = len(self.tracer._current_trace.markov_blanket_message_history)
|
|
944
|
+
logger.info(f"[TRACE_DEBUG] Before end_session: {msg_count} messages in trace")
|
|
945
|
+
|
|
931
946
|
self.session_trace = await self.tracer.end_session()
|
|
932
|
-
|
|
947
|
+
|
|
948
|
+
# Debug: Check if session was saved
|
|
949
|
+
if self.session_trace:
|
|
950
|
+
logger.info(f"[TRACE_DEBUG] Session ended successfully, session_id={self.session_trace.session_id}")
|
|
933
951
|
self.session_trace.metadata.update(self.metadata_updates)
|
|
952
|
+
logger.info(f"[TRACE_DEBUG] session_trace.metadata keys: {list(self.session_trace.metadata.keys())}")
|
|
953
|
+
else:
|
|
954
|
+
logger.warning("[TRACE_DEBUG] end_session returned None!")
|
|
934
955
|
except Exception as exc:
|
|
935
|
-
logger.
|
|
956
|
+
logger.warning(f"TRACING_END_SESSION_FAIL: {exc}", exc_info=True)
|
|
936
957
|
self.session_trace = None
|
|
937
958
|
with contextlib.suppress(Exception):
|
|
938
959
|
await self.tracer.close()
|
|
@@ -1056,12 +1077,14 @@ async def execute_rollout(
|
|
|
1056
1077
|
req: Request,
|
|
1057
1078
|
) -> RolloutResponse:
|
|
1058
1079
|
"""Execute a rollout with coordinated environment and policy steps."""
|
|
1080
|
+
logger.info("ROLLOUT: mode = %s", request.mode)
|
|
1081
|
+
|
|
1059
1082
|
# Emit rollout identifier early for correlation
|
|
1060
1083
|
with contextlib.suppress(Exception):
|
|
1061
1084
|
_rid = getattr(request, "run_id", None)
|
|
1062
1085
|
_pol = getattr(request.policy, "policy_name", None) or getattr(request.policy, "policy_id", None)
|
|
1063
1086
|
_env = getattr(request.env, "env_name", None) or getattr(request.env, "env_id", None)
|
|
1064
|
-
logger.info("ROLLOUT_BEGIN: run_id=%s policy=%s env=%s", _rid, _pol, _env)
|
|
1087
|
+
logger.info("ROLLOUT_BEGIN: run_id=%s policy=%s env=%s mode=%s", _rid, _pol, _env, request.mode)
|
|
1065
1088
|
print(f"[rollout] begin run_id={_rid} policy={_pol} env={_env}", flush=True)
|
|
1066
1089
|
# Enforce per-episode step cap via env-specific parameters; default to 20 if omitted
|
|
1067
1090
|
try:
|
|
@@ -1271,6 +1294,7 @@ async def execute_rollout(
|
|
|
1271
1294
|
config=_policy_config,
|
|
1272
1295
|
rl_run_id=request.run_id,
|
|
1273
1296
|
bound_env_id=env_id,
|
|
1297
|
+
mode=request.mode, # Pass through mode for URL transformation control
|
|
1274
1298
|
),
|
|
1275
1299
|
req,
|
|
1276
1300
|
)
|
|
@@ -1843,14 +1867,73 @@ async def execute_rollout(
|
|
|
1843
1867
|
timing_final.setdefault("overhead_ms", 0.0)
|
|
1844
1868
|
|
|
1845
1869
|
# Build trajectory
|
|
1846
|
-
# Extract inference_url from policy
|
|
1870
|
+
# Extract inference_url from policy config (REQUIRED for trace correlation)
|
|
1871
|
+
# The trainer sets this in policy config with ?cid=... parameter
|
|
1847
1872
|
inference_url = None
|
|
1848
|
-
|
|
1873
|
+
|
|
1874
|
+
# Try policy config from request first (most reliable source)
|
|
1875
|
+
try:
|
|
1876
|
+
policy_config_snapshot = (
|
|
1877
|
+
request.policy.config if isinstance(request.policy.config, dict) else {}
|
|
1878
|
+
)
|
|
1879
|
+
inference_url = policy_config_snapshot.get("inference_url")
|
|
1880
|
+
if inference_url:
|
|
1881
|
+
logger.info(
|
|
1882
|
+
"ROLLOUT_TRAJECTORY: extracted inference_url from request.policy.config run_id=%s url=%s",
|
|
1883
|
+
request.run_id,
|
|
1884
|
+
inference_url,
|
|
1885
|
+
)
|
|
1886
|
+
except Exception as exc:
|
|
1887
|
+
logger.warning(
|
|
1888
|
+
"ROLLOUT_TRAJECTORY: failed to get inference_url from request.policy.config run_id=%s: %s",
|
|
1889
|
+
request.run_id,
|
|
1890
|
+
exc,
|
|
1891
|
+
)
|
|
1892
|
+
|
|
1893
|
+
# Fallback: Try policy handle snapshot (if request.policy.config failed)
|
|
1894
|
+
if not inference_url and policy_handle is not None:
|
|
1849
1895
|
try:
|
|
1850
1896
|
policy_snapshot = policy_handle.snapshot()
|
|
1851
1897
|
inference_url = policy_snapshot.get("config", {}).get("inference_url")
|
|
1852
|
-
|
|
1853
|
-
|
|
1898
|
+
if inference_url:
|
|
1899
|
+
logger.info(
|
|
1900
|
+
"ROLLOUT_TRAJECTORY: extracted inference_url from policy_handle.snapshot run_id=%s url=%s",
|
|
1901
|
+
request.run_id,
|
|
1902
|
+
inference_url,
|
|
1903
|
+
)
|
|
1904
|
+
except Exception as exc:
|
|
1905
|
+
logger.warning(
|
|
1906
|
+
"ROLLOUT_TRAJECTORY: failed to snapshot policy for run_id=%s policy_id=%s: %s",
|
|
1907
|
+
request.run_id,
|
|
1908
|
+
policy_id,
|
|
1909
|
+
exc,
|
|
1910
|
+
)
|
|
1911
|
+
|
|
1912
|
+
# ASSERTION: inference_url MUST be present (required by RolloutTrajectory schema)
|
|
1913
|
+
if not inference_url:
|
|
1914
|
+
raise ValueError(
|
|
1915
|
+
f"FATAL: inference_url is required but not found!\n"
|
|
1916
|
+
f"\n"
|
|
1917
|
+
f"run_id: {request.run_id}\n"
|
|
1918
|
+
f"policy_id: {policy_id}\n"
|
|
1919
|
+
f"policy_config_keys: {list(policy_config_snapshot.keys()) if 'policy_config_snapshot' in locals() else 'N/A'}\n"
|
|
1920
|
+
f"\n"
|
|
1921
|
+
f"The trainer MUST set inference_url in policy config with ?cid=... parameter.\n"
|
|
1922
|
+
f"This is required for trace correlation and hydration.\n"
|
|
1923
|
+
)
|
|
1924
|
+
|
|
1925
|
+
# policy_config_snapshot already set above in try block (line 1876-1878)
|
|
1926
|
+
# Ensure it exists for logging below
|
|
1927
|
+
if 'policy_config_snapshot' not in locals():
|
|
1928
|
+
policy_config_snapshot = {}
|
|
1929
|
+
|
|
1930
|
+
logger.info(
|
|
1931
|
+
"ROLLOUT_TRAJECTORY: run_id=%s policy_id=%s inference_url=%s trace_id=%s",
|
|
1932
|
+
request.run_id,
|
|
1933
|
+
policy_id,
|
|
1934
|
+
inference_url,
|
|
1935
|
+
policy_config_snapshot.get("trace_correlation_id"),
|
|
1936
|
+
)
|
|
1854
1937
|
|
|
1855
1938
|
trajectory = RolloutTrajectory(
|
|
1856
1939
|
env_id=env_id,
|
|
@@ -1948,12 +2031,17 @@ async def execute_rollout(
|
|
|
1948
2031
|
)
|
|
1949
2032
|
finalized = True
|
|
1950
2033
|
trace_payload = tracing_context.build_trace_payload(session_trace)
|
|
2034
|
+
|
|
2035
|
+
# Debug: Check trace payload
|
|
2036
|
+
logger.info(f"[TRACE_DEBUG] trace_payload is None: {trace_payload is None}, return_trace={tracing_context.return_trace}")
|
|
2037
|
+
if trace_payload:
|
|
2038
|
+
logger.info(f"[TRACE_DEBUG] trace_payload keys: {list(trace_payload.keys())}")
|
|
1951
2039
|
|
|
1952
2040
|
# Hard-fail if no steps executed (avg_turns == 0 scenario)
|
|
1953
2041
|
if metrics.num_steps <= 0:
|
|
1954
2042
|
raise HTTPException(status_code=500, detail="no_steps_executed: avg_turns == 0")
|
|
1955
2043
|
|
|
1956
|
-
|
|
2044
|
+
response = RolloutResponse(
|
|
1957
2045
|
run_id=request.run_id,
|
|
1958
2046
|
trajectories=[trajectory],
|
|
1959
2047
|
branches={},
|
|
@@ -1962,6 +2050,16 @@ async def execute_rollout(
|
|
|
1962
2050
|
ops_executed=ops_executed,
|
|
1963
2051
|
trace=trace_payload,
|
|
1964
2052
|
)
|
|
2053
|
+
logger.info(
|
|
2054
|
+
"ROLLOUT_RESPONSE: run_id=%s aborted=%s ops_executed=%s metrics_steps=%s trace_present=%s pipeline_metadata=%s",
|
|
2055
|
+
request.run_id,
|
|
2056
|
+
aborted,
|
|
2057
|
+
ops_executed,
|
|
2058
|
+
metrics.num_steps,
|
|
2059
|
+
bool(trace_payload),
|
|
2060
|
+
response.pipeline_metadata,
|
|
2061
|
+
)
|
|
2062
|
+
return response
|
|
1965
2063
|
|
|
1966
2064
|
except Exception as e:
|
|
1967
2065
|
logger.error(f"Rollout failed for run {request.run_id}: {e}")
|
|
@@ -1,9 +1,165 @@
|
|
|
1
1
|
"""Utility functions for the task service."""
|
|
2
2
|
|
|
3
|
+
import logging
|
|
3
4
|
from typing import Any
|
|
5
|
+
from urllib.parse import parse_qs, urlparse, urlunparse
|
|
4
6
|
|
|
5
7
|
import numpy as np
|
|
6
8
|
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
_CHAT_COMPLETIONS_SUFFIX = "/v1/chat/completions"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def ensure_chat_completions_url(raw_url: Any, mode: str | None = None) -> Any:
|
|
15
|
+
"""
|
|
16
|
+
Ensure inference URLs point at the chat completions endpoint.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
raw_url: The inference URL to process
|
|
20
|
+
mode: "rl" applies URL transformations, "eval" uses URLs as-is (deprecated - use RolloutMode enum)
|
|
21
|
+
|
|
22
|
+
Returns:
|
|
23
|
+
Processed URL (transformed in RL mode, unchanged in EVAL mode)
|
|
24
|
+
"""
|
|
25
|
+
# In EVAL mode, use URLs exactly as provided - no transformations
|
|
26
|
+
# Accept both string "eval" (legacy) and RolloutMode.EVAL
|
|
27
|
+
from synth_ai.task.contracts import RolloutMode
|
|
28
|
+
is_eval_mode = (mode == "eval" or mode == RolloutMode.EVAL or
|
|
29
|
+
(hasattr(mode, 'value') and mode.value == "eval"))
|
|
30
|
+
|
|
31
|
+
if is_eval_mode:
|
|
32
|
+
logger.info("ensure_chat_completions_url: EVAL mode - using URL as-is: %s", raw_url)
|
|
33
|
+
return raw_url
|
|
34
|
+
|
|
35
|
+
# RL mode: apply transformations for compatibility
|
|
36
|
+
if not isinstance(raw_url, str):
|
|
37
|
+
logger.debug("ensure_chat_completions_url: non-string input %r (type=%s)", raw_url, type(raw_url))
|
|
38
|
+
return raw_url
|
|
39
|
+
url = raw_url.strip()
|
|
40
|
+
if not url:
|
|
41
|
+
logger.debug("ensure_chat_completions_url: blank/whitespace URL input")
|
|
42
|
+
return raw_url
|
|
43
|
+
|
|
44
|
+
parsed = urlparse(url)
|
|
45
|
+
path = (parsed.path or "").rstrip("/")
|
|
46
|
+
if path.endswith("/v1/chat/completions"):
|
|
47
|
+
logger.debug("ensure_chat_completions_url: URL already normalized %s", url)
|
|
48
|
+
# Already targeting the desired endpoint; keep original to preserve trailing slash.
|
|
49
|
+
return url
|
|
50
|
+
|
|
51
|
+
if not path:
|
|
52
|
+
new_path = _CHAT_COMPLETIONS_SUFFIX
|
|
53
|
+
else:
|
|
54
|
+
new_path = f"{path}{_CHAT_COMPLETIONS_SUFFIX}"
|
|
55
|
+
|
|
56
|
+
rebuilt = parsed._replace(path=new_path)
|
|
57
|
+
normalized = urlunparse(rebuilt)
|
|
58
|
+
logger.info(
|
|
59
|
+
"ensure_chat_completions_url: RL mode - normalized inference URL from %s to %s",
|
|
60
|
+
url,
|
|
61
|
+
normalized,
|
|
62
|
+
)
|
|
63
|
+
return normalized
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def inference_url_to_trace_correlation_id(raw_url: Any, *, required: bool = False, mode: Any = None) -> str | None:
|
|
67
|
+
"""
|
|
68
|
+
Extract trace_correlation_id from inference URL query params.
|
|
69
|
+
|
|
70
|
+
The inference URL should contain ?cid=trace_xxxxx parameter.
|
|
71
|
+
This is THE canonical source for trace_correlation_id - it's what the
|
|
72
|
+
inference server uses to tag traces, so we extract it here.
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
raw_url: Inference URL (should contain ?cid=... query param)
|
|
76
|
+
required: If True, raises AssertionError if trace_correlation_id not found
|
|
77
|
+
mode: RolloutMode or string ("rl" or "eval"). Controls warning behavior -
|
|
78
|
+
warnings only logged for RL mode, not EVAL mode.
|
|
79
|
+
|
|
80
|
+
Returns:
|
|
81
|
+
trace_correlation_id if found in URL, None otherwise
|
|
82
|
+
|
|
83
|
+
Raises:
|
|
84
|
+
AssertionError: If required=True and trace_correlation_id not found
|
|
85
|
+
"""
|
|
86
|
+
if not isinstance(raw_url, str):
|
|
87
|
+
logger.debug(
|
|
88
|
+
"inference_url_to_trace_correlation_id: non-string input %r (type=%s)",
|
|
89
|
+
raw_url,
|
|
90
|
+
type(raw_url)
|
|
91
|
+
)
|
|
92
|
+
if required:
|
|
93
|
+
raise AssertionError(
|
|
94
|
+
f"FATAL: inference_url_to_trace_correlation_id requires string URL, got {type(raw_url)}: {raw_url!r}"
|
|
95
|
+
)
|
|
96
|
+
return None
|
|
97
|
+
|
|
98
|
+
parsed = urlparse(raw_url)
|
|
99
|
+
query_params = parse_qs(parsed.query or "")
|
|
100
|
+
|
|
101
|
+
# Check all possible parameter names (cid is primary)
|
|
102
|
+
candidates = (
|
|
103
|
+
query_params.get("cid") or
|
|
104
|
+
query_params.get("trace") or
|
|
105
|
+
query_params.get("trace_correlation_id") or
|
|
106
|
+
[]
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
for value in candidates:
|
|
110
|
+
if isinstance(value, str) and value.strip():
|
|
111
|
+
correlation_id = value.strip()
|
|
112
|
+
logger.info(
|
|
113
|
+
"inference_url_to_trace_correlation_id: ✅ extracted id=%s from url=%s",
|
|
114
|
+
correlation_id,
|
|
115
|
+
raw_url,
|
|
116
|
+
)
|
|
117
|
+
# ASSERTION: Correlation ID should look like trace_xxxxx
|
|
118
|
+
assert correlation_id.startswith("trace_"), (
|
|
119
|
+
f"FATAL: trace_correlation_id has unexpected format: {correlation_id!r}. "
|
|
120
|
+
f"Expected to start with 'trace_'"
|
|
121
|
+
)
|
|
122
|
+
return correlation_id
|
|
123
|
+
|
|
124
|
+
# Not found - check if we're in EVAL mode (trace_correlation_id not required for eval)
|
|
125
|
+
from synth_ai.task.contracts import RolloutMode
|
|
126
|
+
is_eval_mode = (mode == "eval" or mode == RolloutMode.EVAL or
|
|
127
|
+
(hasattr(mode, 'value') and mode.value == "eval"))
|
|
128
|
+
|
|
129
|
+
if is_eval_mode:
|
|
130
|
+
# For EVAL mode, missing trace_correlation_id is expected - log as debug, not warning
|
|
131
|
+
logger.debug(
|
|
132
|
+
"inference_url_to_trace_correlation_id: No trace_correlation_id in EVAL mode (expected) url=%s query_params=%s",
|
|
133
|
+
raw_url,
|
|
134
|
+
list(query_params.keys())
|
|
135
|
+
)
|
|
136
|
+
else:
|
|
137
|
+
# For RL mode, missing trace_correlation_id is concerning
|
|
138
|
+
logger.warning(
|
|
139
|
+
"inference_url_to_trace_correlation_id: ❌ NO trace_correlation_id found in url=%s query_params=%s",
|
|
140
|
+
raw_url,
|
|
141
|
+
list(query_params.keys())
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
if required:
|
|
145
|
+
raise AssertionError(
|
|
146
|
+
f"FATAL: trace_correlation_id REQUIRED but not found in inference_url!\n"
|
|
147
|
+
f"\n"
|
|
148
|
+
f"URL: {raw_url}\n"
|
|
149
|
+
f"Query params found: {list(query_params.keys())}\n"
|
|
150
|
+
f"\n"
|
|
151
|
+
f"The inference_url MUST contain ?cid=trace_xxxxx parameter.\n"
|
|
152
|
+
f"This is set by the trainer when generating rollout requests.\n"
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
return None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# Legacy alias for backward compatibility
|
|
159
|
+
def extract_trace_correlation_id(raw_url: Any, mode: Any = None) -> str | None:
    """Deprecated alias kept for older callers.

    Forwards to ``inference_url_to_trace_correlation_id`` with
    ``required=False``; new code should call that function directly.
    """
    correlation_id = inference_url_to_trace_correlation_id(
        raw_url,
        mode=mode,
        required=False,
    )
    return correlation_id
|
|
162
|
+
|
|
7
163
|
|
|
8
164
|
def convert_numpy_to_python(obj: Any) -> Any:
|
|
9
165
|
"""
|