synth-ai 0.2.13.dev2__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +5 -4
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
- examples/sft/evaluate.py +2 -0
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +1 -0
- examples/swe/task_app/hosted/rollout.py +2 -0
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +3 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +306 -8
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +16 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +52 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +111 -13
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +156 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +2 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +2 -0
- examples/task_apps/pokemon_red/task_app.py +199 -6
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +2 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +8 -4
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +258 -23
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +2 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- synth_ai/api/models/supported.py +1 -0
- synth_ai/cli/__init__.py +46 -13
- synth_ai/cli/_modal_wrapper.py +3 -2
- synth_ai/cli/recent.py +1 -1
- synth_ai/cli/status.py +1 -1
- synth_ai/cli/task_apps.py +354 -143
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/tui.py +57 -0
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/verilog/engine.py +76 -10
- synth_ai/judge_schemas.py +8 -8
- synth_ai/task/__init__.py +11 -1
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +257 -0
- synth_ai/task/contracts.py +15 -2
- synth_ai/task/rubrics/__init__.py +3 -0
- synth_ai/task/rubrics/loaders.py +22 -3
- synth_ai/task/rubrics/scoring.py +3 -0
- synth_ai/task/trace_correlation_helpers.py +315 -0
- synth_ai/task/validators.py +144 -0
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/session_tracer.py +16 -6
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +8 -7
- synth_ai/tracing_v3/turso/native_manager.py +63 -40
- synth_ai/tracing_v3/utils.py +3 -3
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +906 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/METADATA +1 -1
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/RECORD +110 -71
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
synth_ai/task/contracts.py
CHANGED
|
@@ -1,11 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
+
from enum import Enum
|
|
4
5
|
from typing import Any, Literal
|
|
5
6
|
|
|
6
7
|
from pydantic import BaseModel, ConfigDict, Field
|
|
7
8
|
|
|
8
9
|
|
|
10
|
+
class RolloutMode(str, Enum):
    """Explicit rollout mode; tells rollout infrastructure how to treat inference URLs.

    Because the class mixes in ``str``, each member compares equal to its plain
    string value (e.g. ``RolloutMode.EVAL == "eval"``).
    """

    # Order matters for iteration: RL first, then EVAL.
    RL = "rl"
    EVAL = "eval"
|
|
14
|
+
|
|
15
|
+
|
|
9
16
|
@dataclass(frozen=True)
|
|
10
17
|
class TaskAppEndpoints:
|
|
11
18
|
"""Required Task App endpoints used by RL trainers and clients.
|
|
@@ -43,7 +50,7 @@ class RolloutRecordConfig(BaseModel):
|
|
|
43
50
|
logprobs: bool = False
|
|
44
51
|
value: bool = False
|
|
45
52
|
return_trace: bool = False
|
|
46
|
-
trace_format: Literal["compact", "full"] = "compact"
|
|
53
|
+
trace_format: Literal["compact", "full", "structured"] = "compact"
|
|
47
54
|
|
|
48
55
|
|
|
49
56
|
class RolloutSafetyConfig(BaseModel):
|
|
@@ -61,6 +68,7 @@ class RolloutRequest(BaseModel):
|
|
|
61
68
|
safety: RolloutSafetyConfig = RolloutSafetyConfig()
|
|
62
69
|
training_session_id: str | None = None
|
|
63
70
|
synth_base_url: str | None = None
|
|
71
|
+
mode: RolloutMode # Required: explicit RL vs EVAL mode
|
|
64
72
|
|
|
65
73
|
|
|
66
74
|
class RolloutStep(BaseModel):
|
|
@@ -110,7 +118,7 @@ class RolloutTrajectory(BaseModel):
|
|
|
110
118
|
|
|
111
119
|
# Required for trace correlation with inference mesh (optional initially for backward compat)
|
|
112
120
|
# See: monorepo/INFERENCE_URL_REQUIREMENT_PLAN.md and trace_creation_and_judgement.txt
|
|
113
|
-
inference_url: str
|
|
121
|
+
inference_url: str
|
|
114
122
|
|
|
115
123
|
decision_samples: list[dict[str, Any]] | None = None
|
|
116
124
|
|
|
@@ -143,10 +151,15 @@ class RolloutResponse(BaseModel):
|
|
|
143
151
|
aborted: bool = False
|
|
144
152
|
ops_executed: int = 0
|
|
145
153
|
|
|
154
|
+
# OPTIONAL: correlation ID for linking rollout to inference traces
|
|
155
|
+
# If not provided, trainer will infer it from trajectory.inference_url ?cid=... parameter
|
|
156
|
+
trace_correlation_id: str | None = None
|
|
157
|
+
|
|
146
158
|
# PREFERRED: v3 trace format (SessionTrace). This is the single source of truth
|
|
147
159
|
# for rollout data and should be used by all new code. Contains richer data than
|
|
148
160
|
# trajectories including token IDs, logprobs, timing, and multimodal content.
|
|
149
161
|
trace: dict[str, Any] | None = None
|
|
162
|
+
pipeline_metadata: dict[str, Any] = Field(default_factory=dict)
|
|
150
163
|
|
|
151
164
|
|
|
152
165
|
class _ExtraAllowModel(BaseModel):
|
synth_ai/task/rubrics/loaders.py
CHANGED
|
@@ -60,15 +60,34 @@ def load_rubric(source: str | dict[str, Any] | Rubric | None) -> Rubric | None:
|
|
|
60
60
|
|
|
61
61
|
Returns:
|
|
62
62
|
Parsed Rubric instance or None if source is None
|
|
63
|
+
|
|
64
|
+
Raises:
|
|
65
|
+
ValueError: If the rubric format is incorrect (e.g., backend judge format)
|
|
66
|
+
ValidationError: If the rubric fails schema validation
|
|
63
67
|
"""
|
|
64
68
|
if source is None:
|
|
65
69
|
return None
|
|
66
70
|
if isinstance(source, Rubric):
|
|
67
71
|
return source
|
|
72
|
+
|
|
73
|
+
# Load and parse the data
|
|
68
74
|
if isinstance(source, dict):
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
75
|
+
data = source
|
|
76
|
+
else:
|
|
77
|
+
text, suffix = _load_text(str(source))
|
|
78
|
+
data = _parse_structured(text, suffix)
|
|
79
|
+
|
|
80
|
+
# Check if this looks like a backend judge rubric (wrong format)
|
|
81
|
+
if isinstance(data, dict) and "event" in data and "outcome" in data:
|
|
82
|
+
# Missing required task app rubric fields
|
|
83
|
+
if "version" not in data and "goal_text" not in data and "criteria" not in data:
|
|
84
|
+
source_hint = f" ({source})" if isinstance(source, str) else ""
|
|
85
|
+
raise ValueError(
|
|
86
|
+
f"Rubric appears to be in backend judge format (has 'event'/'outcome' keys){source_hint}. "
|
|
87
|
+
f"Task apps require rubrics with 'version', 'goal_text', and 'criteria' fields. "
|
|
88
|
+
f"Backend judge rubrics should be named '*_backend_judge.json' and loaded by judge functions."
|
|
89
|
+
)
|
|
90
|
+
|
|
72
91
|
return Rubric.model_validate(data)
|
|
73
92
|
|
|
74
93
|
|
synth_ai/task/trace_correlation_helpers.py
CHANGED
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
"""Helpers for trace correlation ID extraction and inclusion in task apps.
|
|
2
|
+
|
|
3
|
+
This module provides utilities for task apps to:
|
|
4
|
+
1. Extract trace_correlation_id from rollout requests
|
|
5
|
+
2. Include trace_correlation_id in rollout responses (3 required locations)
|
|
6
|
+
|
|
7
|
+
See monorepo/trace_creation_and_judgement.txt "Fatal Guards" section for requirements.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from typing import Any
|
|
12
|
+
from urllib.parse import parse_qs, urlparse
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def extract_trace_correlation_id(
|
|
18
|
+
policy_config: dict[str, Any],
|
|
19
|
+
inference_url: str | None = None,
|
|
20
|
+
mode: Any = None
|
|
21
|
+
) -> str | None:
|
|
22
|
+
"""
|
|
23
|
+
Extract trace_correlation_id from policy config or inference URL.
|
|
24
|
+
|
|
25
|
+
This is the standardized method for all task apps to extract the correlation ID
|
|
26
|
+
that the RL trainer generates and passes to the task app.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
policy_config: Policy configuration dict from RolloutRequest.policy.config
|
|
30
|
+
inference_url: Inference URL (optional, used as fallback)
|
|
31
|
+
mode: RolloutMode or string ("rl" or "eval"). Controls warning behavior -
|
|
32
|
+
warnings only logged for RL mode, not EVAL mode.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
trace_correlation_id if found, None otherwise
|
|
36
|
+
|
|
37
|
+
Extraction order:
|
|
38
|
+
1. policy_config["trace_correlation_id"] (preferred)
|
|
39
|
+
2. policy_config["trace"] (legacy fallback)
|
|
40
|
+
3. URL query param ?cid=... (fallback)
|
|
41
|
+
4. URL query param ?trace_correlation_id=... (fallback)
|
|
42
|
+
"""
|
|
43
|
+
# Try policy_config first (preferred method)
|
|
44
|
+
candidates: list[Any] = [
|
|
45
|
+
policy_config.get("trace_correlation_id"),
|
|
46
|
+
policy_config.get("trace"),
|
|
47
|
+
]
|
|
48
|
+
|
|
49
|
+
logger.debug(
|
|
50
|
+
"extract_trace_correlation_id: policy_cfg keys=%s candidates=%s",
|
|
51
|
+
sorted(policy_config.keys()),
|
|
52
|
+
candidates,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
for candidate in candidates:
|
|
56
|
+
if isinstance(candidate, str):
|
|
57
|
+
stripped = candidate.strip()
|
|
58
|
+
if stripped:
|
|
59
|
+
logger.info(
|
|
60
|
+
"extract_trace_correlation_id: extracted from policy_config=%s",
|
|
61
|
+
stripped
|
|
62
|
+
)
|
|
63
|
+
return stripped
|
|
64
|
+
|
|
65
|
+
# Determine if we're in EVAL mode (trace_correlation_id not required for eval)
|
|
66
|
+
try:
|
|
67
|
+
from synth_ai.task.contracts import RolloutMode
|
|
68
|
+
is_eval_mode = (mode == "eval" or mode == RolloutMode.EVAL or
|
|
69
|
+
(hasattr(mode, 'value') and mode.value == "eval"))
|
|
70
|
+
except ImportError:
|
|
71
|
+
# If RolloutMode not available, fall back to string comparison
|
|
72
|
+
is_eval_mode = (mode == "eval")
|
|
73
|
+
|
|
74
|
+
# Fallback: try to extract from inference_url query params
|
|
75
|
+
if not inference_url or not isinstance(inference_url, str):
|
|
76
|
+
if is_eval_mode:
|
|
77
|
+
logger.debug(
|
|
78
|
+
"extract_trace_correlation_id: no correlation ID found in policy_config "
|
|
79
|
+
"and no inference_url provided (EVAL mode - expected)"
|
|
80
|
+
)
|
|
81
|
+
else:
|
|
82
|
+
logger.warning(
|
|
83
|
+
"extract_trace_correlation_id: no correlation ID found in policy_config "
|
|
84
|
+
"and no inference_url provided"
|
|
85
|
+
)
|
|
86
|
+
return None
|
|
87
|
+
|
|
88
|
+
try:
|
|
89
|
+
parsed = urlparse(inference_url)
|
|
90
|
+
query_params = parse_qs(parsed.query or "")
|
|
91
|
+
# Try multiple possible query param names
|
|
92
|
+
for param_name in ["cid", "trace_correlation_id", "trace"]:
|
|
93
|
+
values = query_params.get(param_name, [])
|
|
94
|
+
for value in values:
|
|
95
|
+
if isinstance(value, str) and value.strip():
|
|
96
|
+
correlation_id = value.strip()
|
|
97
|
+
logger.info(
|
|
98
|
+
"extract_trace_correlation_id: extracted from URL param %s=%s",
|
|
99
|
+
param_name,
|
|
100
|
+
correlation_id,
|
|
101
|
+
)
|
|
102
|
+
return correlation_id
|
|
103
|
+
except Exception as e:
|
|
104
|
+
logger.warning(
|
|
105
|
+
"extract_trace_correlation_id: failed to parse inference_url=%s error=%s",
|
|
106
|
+
inference_url,
|
|
107
|
+
e,
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
if is_eval_mode:
|
|
111
|
+
logger.debug(
|
|
112
|
+
"extract_trace_correlation_id: no trace_correlation_id found in "
|
|
113
|
+
"policy_config or inference_url=%s (EVAL mode - expected)",
|
|
114
|
+
inference_url,
|
|
115
|
+
)
|
|
116
|
+
else:
|
|
117
|
+
logger.warning(
|
|
118
|
+
"extract_trace_correlation_id: no trace_correlation_id found in "
|
|
119
|
+
"policy_config or inference_url=%s",
|
|
120
|
+
inference_url,
|
|
121
|
+
)
|
|
122
|
+
return None
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def validate_trace_correlation_id(
|
|
126
|
+
trace_correlation_id: str | None,
|
|
127
|
+
run_id: str,
|
|
128
|
+
policy_config: dict[str, Any],
|
|
129
|
+
fatal: bool = False
|
|
130
|
+
) -> str | None:
|
|
131
|
+
"""
|
|
132
|
+
Validate that trace_correlation_id was successfully extracted.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
trace_correlation_id: The extracted correlation ID (or None)
|
|
136
|
+
run_id: Rollout run_id for logging
|
|
137
|
+
policy_config: Policy configuration for debugging
|
|
138
|
+
fatal: If True, raise ValueError on missing ID. If False, log error only.
|
|
139
|
+
|
|
140
|
+
Returns:
|
|
141
|
+
trace_correlation_id if present, None if missing (when fatal=False)
|
|
142
|
+
|
|
143
|
+
Raises:
|
|
144
|
+
ValueError: If trace_correlation_id is missing and fatal=True
|
|
145
|
+
"""
|
|
146
|
+
if not trace_correlation_id:
|
|
147
|
+
error_msg = (
|
|
148
|
+
f"🚨 CRITICAL: Cannot extract trace_correlation_id!\n"
|
|
149
|
+
"\n"
|
|
150
|
+
f"Run ID: {run_id}\n"
|
|
151
|
+
f"Policy config keys: {sorted(policy_config.keys())}\n"
|
|
152
|
+
f"Inference URL: {policy_config.get('inference_url', 'NOT_SET')}\n"
|
|
153
|
+
"\n"
|
|
154
|
+
"Checked:\n"
|
|
155
|
+
f"1. policy_config['trace_correlation_id']: {policy_config.get('trace_correlation_id')}\n"
|
|
156
|
+
f"2. policy_config['trace']: {policy_config.get('trace')}\n"
|
|
157
|
+
f"3. inference_url query params\n"
|
|
158
|
+
"\n"
|
|
159
|
+
"Task app CANNOT proceed without trace_correlation_id.\n"
|
|
160
|
+
"This indicates the RL trainer is not sending it correctly.\n"
|
|
161
|
+
"\n"
|
|
162
|
+
"See monorepo/trace_creation_and_judgement.txt 'Fatal Guards' section.\n"
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
if fatal:
|
|
166
|
+
raise ValueError(error_msg)
|
|
167
|
+
else:
|
|
168
|
+
logger.error(error_msg)
|
|
169
|
+
|
|
170
|
+
return trace_correlation_id
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def include_trace_correlation_id_in_response(
|
|
174
|
+
response_data: dict[str, Any],
|
|
175
|
+
trace_correlation_id: str | None,
|
|
176
|
+
run_id: str
|
|
177
|
+
) -> dict[str, Any]:
|
|
178
|
+
"""
|
|
179
|
+
Include trace_correlation_id in all required locations of rollout response.
|
|
180
|
+
|
|
181
|
+
Required locations (per Fatal Guards section):
|
|
182
|
+
1. Top-level response["trace_correlation_id"]
|
|
183
|
+
2. response["pipeline_metadata"]["trace_correlation_id"]
|
|
184
|
+
3. Each trajectory["trace_correlation_id"]
|
|
185
|
+
|
|
186
|
+
Args:
|
|
187
|
+
response_data: RolloutResponse dict (from .model_dump())
|
|
188
|
+
trace_correlation_id: The correlation ID to include
|
|
189
|
+
run_id: Rollout run_id for logging
|
|
190
|
+
|
|
191
|
+
Returns:
|
|
192
|
+
Modified response_data with trace_correlation_id in all required places
|
|
193
|
+
"""
|
|
194
|
+
if not trace_correlation_id:
|
|
195
|
+
logger.error(
|
|
196
|
+
"include_trace_correlation_id_in_response: missing trace_correlation_id "
|
|
197
|
+
"for run_id=%s - cannot include in response",
|
|
198
|
+
run_id
|
|
199
|
+
)
|
|
200
|
+
return response_data
|
|
201
|
+
|
|
202
|
+
# 1. Add to top-level (REQUIRED)
|
|
203
|
+
if "trace_correlation_id" not in response_data:
|
|
204
|
+
response_data["trace_correlation_id"] = trace_correlation_id
|
|
205
|
+
logger.info(
|
|
206
|
+
"include_trace_correlation_id: added to top-level run_id=%s cid=%s",
|
|
207
|
+
run_id,
|
|
208
|
+
trace_correlation_id
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
# 2. Add to pipeline_metadata (REQUIRED)
|
|
212
|
+
pipeline_meta = response_data.get("pipeline_metadata")
|
|
213
|
+
if not isinstance(pipeline_meta, dict):
|
|
214
|
+
pipeline_meta = {}
|
|
215
|
+
response_data["pipeline_metadata"] = pipeline_meta
|
|
216
|
+
|
|
217
|
+
if "trace_correlation_id" not in pipeline_meta:
|
|
218
|
+
pipeline_meta["trace_correlation_id"] = trace_correlation_id
|
|
219
|
+
logger.info(
|
|
220
|
+
"include_trace_correlation_id: added to pipeline_metadata run_id=%s cid=%s",
|
|
221
|
+
run_id,
|
|
222
|
+
trace_correlation_id
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
# 3. Add to each trajectory (REQUIRED)
|
|
226
|
+
trajectories = response_data.get("trajectories", [])
|
|
227
|
+
if isinstance(trajectories, list):
|
|
228
|
+
for idx, traj in enumerate(trajectories):
|
|
229
|
+
if isinstance(traj, dict) and "trace_correlation_id" not in traj:
|
|
230
|
+
traj["trace_correlation_id"] = trace_correlation_id
|
|
231
|
+
logger.debug(
|
|
232
|
+
"include_trace_correlation_id: added to trajectory[%d] run_id=%s cid=%s",
|
|
233
|
+
idx,
|
|
234
|
+
run_id,
|
|
235
|
+
trace_correlation_id
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
logger.info(
|
|
239
|
+
"include_trace_correlation_id: completed run_id=%s cid=%s "
|
|
240
|
+
"added to %d locations (top-level, metadata, %d trajectories)",
|
|
241
|
+
run_id,
|
|
242
|
+
trace_correlation_id,
|
|
243
|
+
2 + len(trajectories),
|
|
244
|
+
len(trajectories)
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
return response_data
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def verify_trace_correlation_id_in_response(
|
|
251
|
+
response_data: dict[str, Any],
|
|
252
|
+
expected_correlation_id: str | None,
|
|
253
|
+
run_id: str
|
|
254
|
+
) -> bool:
|
|
255
|
+
"""
|
|
256
|
+
Verify that trace_correlation_id is present in all required locations.
|
|
257
|
+
|
|
258
|
+
Args:
|
|
259
|
+
response_data: RolloutResponse dict to verify
|
|
260
|
+
expected_correlation_id: The correlation ID that should be present
|
|
261
|
+
run_id: Rollout run_id for logging
|
|
262
|
+
|
|
263
|
+
Returns:
|
|
264
|
+
True if all required locations have the correlation ID, False otherwise
|
|
265
|
+
"""
|
|
266
|
+
if not expected_correlation_id:
|
|
267
|
+
logger.error(
|
|
268
|
+
"verify_trace_correlation_id: no expected_correlation_id provided for run_id=%s",
|
|
269
|
+
run_id
|
|
270
|
+
)
|
|
271
|
+
return False
|
|
272
|
+
|
|
273
|
+
errors = []
|
|
274
|
+
|
|
275
|
+
# Check top-level
|
|
276
|
+
if response_data.get("trace_correlation_id") != expected_correlation_id:
|
|
277
|
+
errors.append(
|
|
278
|
+
f"Top-level missing or mismatch: "
|
|
279
|
+
f"expected={expected_correlation_id} actual={response_data.get('trace_correlation_id')}"
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
# Check pipeline_metadata
|
|
283
|
+
pipeline_meta = response_data.get("pipeline_metadata", {})
|
|
284
|
+
if not isinstance(pipeline_meta, dict) or pipeline_meta.get("trace_correlation_id") != expected_correlation_id:
|
|
285
|
+
errors.append(
|
|
286
|
+
f"pipeline_metadata missing or mismatch: "
|
|
287
|
+
f"expected={expected_correlation_id} actual={pipeline_meta.get('trace_correlation_id') if isinstance(pipeline_meta, dict) else 'NOT_A_DICT'}"
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
# Check trajectories
|
|
291
|
+
trajectories = response_data.get("trajectories", [])
|
|
292
|
+
if isinstance(trajectories, list):
|
|
293
|
+
for idx, traj in enumerate(trajectories):
|
|
294
|
+
if isinstance(traj, dict) and traj.get("trace_correlation_id") != expected_correlation_id:
|
|
295
|
+
errors.append(
|
|
296
|
+
f"trajectory[{idx}] missing or mismatch: "
|
|
297
|
+
f"expected={expected_correlation_id} actual={traj.get('trace_correlation_id')}"
|
|
298
|
+
)
|
|
299
|
+
|
|
300
|
+
if errors:
|
|
301
|
+
logger.error(
|
|
302
|
+
"verify_trace_correlation_id: FAILED run_id=%s\n%s",
|
|
303
|
+
run_id,
|
|
304
|
+
"\n".join(errors)
|
|
305
|
+
)
|
|
306
|
+
return False
|
|
307
|
+
|
|
308
|
+
logger.info(
|
|
309
|
+
"verify_trace_correlation_id: PASSED run_id=%s cid=%s",
|
|
310
|
+
run_id,
|
|
311
|
+
expected_correlation_id
|
|
312
|
+
)
|
|
313
|
+
return True
|
|
314
|
+
|
|
315
|
+
|
synth_ai/task/validators.py
CHANGED
|
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import re
|
|
6
6
|
from typing import Any
|
|
7
|
+
from urllib.parse import urlparse, urlunparse
|
|
7
8
|
|
|
8
9
|
import click
|
|
9
10
|
import httpx
|
|
@@ -11,6 +12,149 @@ import httpx
|
|
|
11
12
|
from synth_ai.task.contracts import TaskAppEndpoints # type: ignore[attr-defined]
|
|
12
13
|
|
|
13
14
|
|
|
15
|
+
def validate_rollout_response_for_rl(response_data: dict[str, Any], *, warn_only: bool = False) -> list[str]:
    """Validate that a task app rollout response has required fields for RL training.

    The backend RL trainer requires:
    1. pipeline_metadata["inference_url"] at top level, carrying a non-empty ``cid``
       query parameter for trace correlation (``?cid=...`` or ``&cid=...``)
    2. Each step's info.meta["inference_url"] must be present (nested structure!)

    Args:
        response_data: The rollout response dict from task app
        warn_only: If True, return warnings instead of raising exceptions

    Returns:
        List of validation warnings/errors (empty when the response is valid)

    Raises:
        ValueError: If critical fields are missing (unless warn_only=True)
    """
    from urllib.parse import parse_qs, urlparse

    def _has_cid(url: str) -> bool:
        # Look for a non-blank ``cid`` anywhere in the query string, not only as
        # the first parameter (a plain "?cid=" substring test misses "&cid=").
        try:
            query = urlparse(url).query
        except ValueError:
            return False
        return any(value.strip() for value in parse_qs(query).get("cid", []))

    issues: list[str] = []

    # Check pipeline_metadata
    pipeline_metadata = response_data.get("pipeline_metadata")
    if not isinstance(pipeline_metadata, dict):
        issues.append("Missing or invalid 'pipeline_metadata' (required for RL training)")
    else:
        inference_url = pipeline_metadata.get("inference_url")
        if not inference_url:
            issues.append(
                "pipeline_metadata['inference_url'] is missing. "
                "RL trainer requires this field to extract traces."
            )
        elif not isinstance(inference_url, str):
            issues.append(
                f"pipeline_metadata['inference_url'] must be a string, got: {type(inference_url).__name__}"
            )
        elif not _has_cid(inference_url):
            issues.append(
                "pipeline_metadata['inference_url'] should contain '?cid=' for trace correlation. "
                f"Got: {inference_url[:80]}..."
            )

    # Check trajectories and steps. A None/non-sequence value is reported as
    # "no trajectories" rather than crashing enumerate().
    trajectories = response_data.get("trajectories")
    if not isinstance(trajectories, (list, tuple)):
        trajectories = []
    if not trajectories:
        issues.append("No trajectories found in response")

    for traj_idx, trajectory in enumerate(trajectories):
        if not isinstance(trajectory, dict):
            continue

        steps = trajectory.get("steps")
        if not isinstance(steps, (list, tuple)):
            steps = []
        for step_idx, step in enumerate(steps):
            if not isinstance(step, dict):
                continue

            step_info = step.get("info", {})
            if not isinstance(step_info, dict):
                issues.append(
                    f"trajectory[{traj_idx}].steps[{step_idx}].info is not a dict"
                )
                continue

            # Check for nested meta.inference_url (backend expects this structure!)
            step_meta = step_info.get("meta", {})
            if not isinstance(step_meta, dict):
                issues.append(
                    f"trajectory[{traj_idx}].steps[{step_idx}].info.meta is missing or not a dict. "
                    "RL trainer expects nested structure: info.meta.inference_url"
                )
                continue

            step_inference_url = step_meta.get("inference_url")
            if not step_inference_url:
                issues.append(
                    f"trajectory[{traj_idx}].steps[{step_idx}].info.meta['inference_url'] is missing. "
                    "RL trainer needs this for trace extraction (nested structure required!)"
                )
            elif not isinstance(step_inference_url, str):
                issues.append(
                    f"trajectory[{traj_idx}].steps[{step_idx}].info.meta['inference_url'] must be a string, "
                    f"got: {type(step_inference_url).__name__}"
                )

    if issues and not warn_only:
        error_msg = "Task app response validation failed for RL training:\n" + "\n".join(
            f"  - {issue}" for issue in issues
        )
        raise ValueError(error_msg)

    return issues
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def normalize_inference_url(url: str | None, *, default: str = "https://api.openai.com/v1/chat/completions") -> str:
|
|
107
|
+
"""Normalize an inference URL to include the /v1/chat/completions path.
|
|
108
|
+
|
|
109
|
+
This utility ensures inference URLs have the correct path structure for OpenAI-compatible
|
|
110
|
+
chat completions endpoints, while preserving query parameters (e.g., ?cid=trace_123)
|
|
111
|
+
that may be added for tracing.
|
|
112
|
+
|
|
113
|
+
Args:
|
|
114
|
+
url: The inference URL to normalize (may be None or incomplete)
|
|
115
|
+
default: Default URL to use if url is None/empty
|
|
116
|
+
|
|
117
|
+
Returns:
|
|
118
|
+
Normalized URL with proper path and preserved query parameters
|
|
119
|
+
|
|
120
|
+
Examples:
|
|
121
|
+
>>> normalize_inference_url("https://api.groq.com")
|
|
122
|
+
'https://api.groq.com/v1/chat/completions'
|
|
123
|
+
|
|
124
|
+
>>> normalize_inference_url("https://modal.host?cid=trace_123")
|
|
125
|
+
'https://modal.host/v1/chat/completions?cid=trace_123'
|
|
126
|
+
|
|
127
|
+
>>> normalize_inference_url("https://api.openai.com/v1")
|
|
128
|
+
'https://api.openai.com/v1/chat/completions'
|
|
129
|
+
|
|
130
|
+
>>> normalize_inference_url("https://api.groq.com/openai/v1/chat/completions")
|
|
131
|
+
'https://api.groq.com/openai/v1/chat/completions'
|
|
132
|
+
"""
|
|
133
|
+
candidate = (url or default).strip()
|
|
134
|
+
if not candidate:
|
|
135
|
+
candidate = default
|
|
136
|
+
|
|
137
|
+
# Parse the URL to separate path and query components
|
|
138
|
+
parsed = urlparse(candidate)
|
|
139
|
+
|
|
140
|
+
# Check if path already ends with a completions endpoint
|
|
141
|
+
path = parsed.path.rstrip('/')
|
|
142
|
+
if path.endswith("/v1/chat/completions") or path.endswith("/chat/completions"):
|
|
143
|
+
return candidate
|
|
144
|
+
|
|
145
|
+
# Determine what to append based on existing path
|
|
146
|
+
if path.endswith("/v1"):
|
|
147
|
+
new_path = f"{path}/chat/completions"
|
|
148
|
+
elif path.endswith("/chat"):
|
|
149
|
+
new_path = f"{path}/completions"
|
|
150
|
+
else:
|
|
151
|
+
# Default: append full path
|
|
152
|
+
new_path = f"{path}/v1/chat/completions" if path else "/v1/chat/completions"
|
|
153
|
+
|
|
154
|
+
# Reconstruct URL with new path and original query/fragment
|
|
155
|
+
return urlunparse(parsed._replace(path=new_path))
|
|
156
|
+
|
|
157
|
+
|
|
14
158
|
def validate_task_app_url(url: str | None) -> str:
|
|
15
159
|
"""Validate and normalize a task app URL.
|
|
16
160
|
|
|
@@ -37,7 +37,7 @@ Concepts:
|
|
|
37
37
|
from __future__ import annotations
|
|
38
38
|
|
|
39
39
|
from dataclasses import asdict, dataclass, field
|
|
40
|
-
from datetime import
|
|
40
|
+
from datetime import datetime, timezone
|
|
41
41
|
from typing import Any
|
|
42
42
|
|
|
43
43
|
from .lm_call_record_abstractions import LLMCallRecord
|
|
@@ -249,7 +249,7 @@ class SessionTimeStep:
|
|
|
249
249
|
|
|
250
250
|
step_id: str = ""
|
|
251
251
|
step_index: int = 0
|
|
252
|
-
timestamp: datetime = field(default_factory=lambda: datetime.now(
|
|
252
|
+
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
253
253
|
turn_number: int | None = None
|
|
254
254
|
events: list[BaseEvent] = field(default_factory=list)
|
|
255
255
|
markov_blanket_messages: list[SessionEventMarkovBlanketMessage] = field(default_factory=list)
|
|
@@ -283,7 +283,7 @@ class SessionTrace:
|
|
|
283
283
|
"""
|
|
284
284
|
|
|
285
285
|
session_id: str = ""
|
|
286
|
-
created_at: datetime = field(default_factory=lambda: datetime.now(
|
|
286
|
+
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
287
287
|
session_time_steps: list[SessionTimeStep] = field(default_factory=list)
|
|
288
288
|
event_history: list[BaseEvent] = field(default_factory=list)
|
|
289
289
|
markov_blanket_message_history: list[SessionEventMarkovBlanketMessage] = field(
|
|
@@ -8,7 +8,7 @@ from __future__ import annotations
|
|
|
8
8
|
|
|
9
9
|
import uuid
|
|
10
10
|
from dataclasses import dataclass, field
|
|
11
|
-
from datetime import
|
|
11
|
+
from datetime import datetime, timezone
|
|
12
12
|
from typing import Any, TypedDict, cast
|
|
13
13
|
|
|
14
14
|
from .lm_call_record_abstractions import (
|
|
@@ -180,8 +180,8 @@ def create_llm_call_record_from_response(
|
|
|
180
180
|
api_type=api_type,
|
|
181
181
|
provider=provider,
|
|
182
182
|
model_name=model_name,
|
|
183
|
-
started_at=started_at or datetime.now(
|
|
184
|
-
completed_at=completed_at or datetime.now(
|
|
183
|
+
started_at=started_at or datetime.now(timezone.utc),
|
|
184
|
+
completed_at=completed_at or datetime.now(timezone.utc),
|
|
185
185
|
latency_ms=latency_ms,
|
|
186
186
|
request_params=params,
|
|
187
187
|
input_messages=input_messages,
|
|
@@ -376,8 +376,8 @@ def create_llm_call_record_from_streaming(
|
|
|
376
376
|
api_type="responses", # Streaming typically from Responses API
|
|
377
377
|
provider=provider,
|
|
378
378
|
model_name=model_name,
|
|
379
|
-
started_at=started_at or datetime.now(
|
|
380
|
-
completed_at=completed_at or datetime.now(
|
|
379
|
+
started_at=started_at or datetime.now(timezone.utc),
|
|
380
|
+
completed_at=completed_at or datetime.now(timezone.utc),
|
|
381
381
|
latency_ms=latency_ms,
|
|
382
382
|
request_params=params,
|
|
383
383
|
input_messages=input_messages,
|