synth-ai 0.2.13.dev2__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (110) hide show
  1. examples/multi_step/configs/README_verilog_rl.md +77 -0
  2. examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
  3. examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
  4. examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
  5. examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
  6. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +5 -4
  7. examples/multi_step/configs/crafter_synth_backend.md +40 -0
  8. examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
  9. examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
  10. examples/multi_step/configs/verilog_rl_lora.toml +190 -0
  11. examples/multi_step/judges/crafter_backend_judge.py +220 -0
  12. examples/multi_step/judges/verilog_backend_judge.py +234 -0
  13. examples/multi_step/readme.md +48 -0
  14. examples/multi_step/verilog_rl_lora.md +218 -0
  15. examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
  16. examples/sft/evaluate.py +2 -0
  17. examples/sft/generate_traces.py +2 -0
  18. examples/swe/task_app/grpo_swe_mini.py +1 -0
  19. examples/swe/task_app/hosted/rollout.py +2 -0
  20. examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
  21. examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
  22. examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
  23. examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
  24. examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
  25. examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
  26. examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
  27. examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
  28. examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
  29. examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
  30. examples/task_apps/crafter/task_app/__init__.py +3 -0
  31. examples/task_apps/crafter/task_app/grpo_crafter.py +306 -8
  32. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
  33. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +16 -3
  34. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
  35. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
  36. examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +52 -1
  37. examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +111 -13
  38. examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +156 -0
  39. examples/task_apps/enron/filter_sft.toml +5 -0
  40. examples/task_apps/enron/tests/__init__.py +2 -0
  41. examples/task_apps/enron/tests/integration/__init__.py +2 -0
  42. examples/task_apps/enron/tests/integration/test_enron_eval.py +2 -0
  43. examples/task_apps/enron/tests/unit/__init__.py +2 -0
  44. examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
  45. examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
  46. examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
  47. examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
  48. examples/task_apps/pokemon_red/pallet_town_rl_config.toml +2 -0
  49. examples/task_apps/pokemon_red/task_app.py +199 -6
  50. examples/task_apps/pokemon_red/test_pallet_town_rewards.py +2 -0
  51. examples/task_apps/sokoban/filter_sft.toml +5 -0
  52. examples/task_apps/sokoban/tests/__init__.py +2 -0
  53. examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
  54. examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
  55. examples/task_apps/verilog/eval_groq_qwen32b.toml +8 -4
  56. examples/task_apps/verilog/filter_sft.toml +5 -0
  57. examples/task_apps/verilog/task_app/grpo_verilog.py +258 -23
  58. examples/task_apps/verilog/tests/__init__.py +2 -0
  59. examples/task_apps/verilog/tests/integration/__init__.py +2 -0
  60. examples/task_apps/verilog/tests/integration/test_verilog_eval.py +2 -0
  61. examples/task_apps/verilog/tests/unit/__init__.py +2 -0
  62. examples/warming_up_to_rl/groq_test.py +2 -0
  63. examples/warming_up_to_rl/run_local_rollout.py +2 -0
  64. examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
  65. examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
  66. examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
  67. examples/warming_up_to_rl/run_rollout_remote.py +2 -0
  68. synth_ai/api/models/supported.py +1 -0
  69. synth_ai/cli/__init__.py +46 -13
  70. synth_ai/cli/_modal_wrapper.py +3 -2
  71. synth_ai/cli/recent.py +1 -1
  72. synth_ai/cli/status.py +1 -1
  73. synth_ai/cli/task_apps.py +354 -143
  74. synth_ai/cli/traces.py +1 -1
  75. synth_ai/cli/tui.py +57 -0
  76. synth_ai/cli/turso.py +1 -1
  77. synth_ai/cli/watch.py +1 -1
  78. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
  79. synth_ai/environments/examples/crafter_classic/environment.py +1 -1
  80. synth_ai/environments/examples/verilog/engine.py +76 -10
  81. synth_ai/judge_schemas.py +8 -8
  82. synth_ai/task/__init__.py +11 -1
  83. synth_ai/task/apps/__init__.py +1 -0
  84. synth_ai/task/config.py +257 -0
  85. synth_ai/task/contracts.py +15 -2
  86. synth_ai/task/rubrics/__init__.py +3 -0
  87. synth_ai/task/rubrics/loaders.py +22 -3
  88. synth_ai/task/rubrics/scoring.py +3 -0
  89. synth_ai/task/trace_correlation_helpers.py +315 -0
  90. synth_ai/task/validators.py +144 -0
  91. synth_ai/tracing_v3/abstractions.py +3 -3
  92. synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
  93. synth_ai/tracing_v3/session_tracer.py +16 -6
  94. synth_ai/tracing_v3/storage/base.py +29 -29
  95. synth_ai/tracing_v3/storage/config.py +3 -3
  96. synth_ai/tracing_v3/turso/daemon.py +8 -7
  97. synth_ai/tracing_v3/turso/native_manager.py +63 -40
  98. synth_ai/tracing_v3/utils.py +3 -3
  99. synth_ai/tui/__init__.py +5 -0
  100. synth_ai/tui/__main__.py +13 -0
  101. synth_ai/tui/cli/__init__.py +1 -0
  102. synth_ai/tui/cli/query_experiments.py +164 -0
  103. synth_ai/tui/cli/query_experiments_v3.py +164 -0
  104. synth_ai/tui/dashboard.py +906 -0
  105. {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/METADATA +1 -1
  106. {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/RECORD +110 -71
  107. {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
  108. {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
  109. {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
  110. {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
+ from enum import Enum
4
5
  from typing import Any, Literal
5
6
 
6
7
  from pydantic import BaseModel, ConfigDict, Field
7
8
 
8
9
 
10
class RolloutMode(str, Enum):
    """Mode controls how rollout infrastructure processes inference URLs."""

    # RL training rollouts: trace correlation with the trainer is expected,
    # so missing correlation IDs are surfaced as warnings downstream.
    RL = "rl"
    # Evaluation rollouts: a missing trace correlation ID is expected and
    # only logged at debug level downstream.
    EVAL = "eval"
14
+
15
+
9
16
  @dataclass(frozen=True)
10
17
  class TaskAppEndpoints:
11
18
  """Required Task App endpoints used by RL trainers and clients.
@@ -43,7 +50,7 @@ class RolloutRecordConfig(BaseModel):
43
50
  logprobs: bool = False
44
51
  value: bool = False
45
52
  return_trace: bool = False
46
- trace_format: Literal["compact", "full"] = "compact"
53
+ trace_format: Literal["compact", "full", "structured"] = "compact"
47
54
 
48
55
 
49
56
  class RolloutSafetyConfig(BaseModel):
@@ -61,6 +68,7 @@ class RolloutRequest(BaseModel):
61
68
  safety: RolloutSafetyConfig = RolloutSafetyConfig()
62
69
  training_session_id: str | None = None
63
70
  synth_base_url: str | None = None
71
+ mode: RolloutMode # Required: explicit RL vs EVAL mode
64
72
 
65
73
 
66
74
  class RolloutStep(BaseModel):
@@ -110,7 +118,7 @@ class RolloutTrajectory(BaseModel):
110
118
 
111
119
  # Required for trace correlation with inference mesh (optional initially for backward compat)
112
120
  # See: monorepo/INFERENCE_URL_REQUIREMENT_PLAN.md and trace_creation_and_judgement.txt
113
- inference_url: str | None = None
121
+ inference_url: str
114
122
 
115
123
  decision_samples: list[dict[str, Any]] | None = None
116
124
 
@@ -143,10 +151,15 @@ class RolloutResponse(BaseModel):
143
151
  aborted: bool = False
144
152
  ops_executed: int = 0
145
153
 
154
+ # OPTIONAL: correlation ID for linking rollout to inference traces
155
+ # If not provided, trainer will infer it from trajectory.inference_url ?cid=... parameter
156
+ trace_correlation_id: str | None = None
157
+
146
158
  # PREFERRED: v3 trace format (SessionTrace). This is the single source of truth
147
159
  # for rollout data and should be used by all new code. Contains richer data than
148
160
  # trajectories including token IDs, logprobs, timing, and multimodal content.
149
161
  trace: dict[str, Any] | None = None
162
+ pipeline_metadata: dict[str, Any] = Field(default_factory=dict)
150
163
 
151
164
 
152
165
  class _ExtraAllowModel(BaseModel):
@@ -51,3 +51,6 @@ __all__ = [
51
51
  RubricCriterion = StrictCriterion
52
52
  RubricSpec = StrictRubric
53
53
 
54
+
55
+
56
+
@@ -60,15 +60,34 @@ def load_rubric(source: str | dict[str, Any] | Rubric | None) -> Rubric | None:
60
60
 
61
61
  Returns:
62
62
  Parsed Rubric instance or None if source is None
63
+
64
+ Raises:
65
+ ValueError: If the rubric format is incorrect (e.g., backend judge format)
66
+ ValidationError: If the rubric fails schema validation
63
67
  """
64
68
  if source is None:
65
69
  return None
66
70
  if isinstance(source, Rubric):
67
71
  return source
72
+
73
+ # Load and parse the data
68
74
  if isinstance(source, dict):
69
- return Rubric.model_validate(source)
70
- text, suffix = _load_text(str(source))
71
- data = _parse_structured(text, suffix)
75
+ data = source
76
+ else:
77
+ text, suffix = _load_text(str(source))
78
+ data = _parse_structured(text, suffix)
79
+
80
+ # Check if this looks like a backend judge rubric (wrong format)
81
+ if isinstance(data, dict) and "event" in data and "outcome" in data:
82
+ # Missing required task app rubric fields
83
+ if "version" not in data and "goal_text" not in data and "criteria" not in data:
84
+ source_hint = f" ({source})" if isinstance(source, str) else ""
85
+ raise ValueError(
86
+ f"Rubric appears to be in backend judge format (has 'event'/'outcome' keys){source_hint}. "
87
+ f"Task apps require rubrics with 'version', 'goal_text', and 'criteria' fields. "
88
+ f"Backend judge rubrics should be named '*_backend_judge.json' and loaded by judge functions."
89
+ )
90
+
72
91
  return Rubric.model_validate(data)
73
92
 
74
93
 
@@ -111,3 +111,6 @@ def score_outcome_against_rubric(outcome: dict[str, Any], rubric: Rubric | None)
111
111
  values[str(key)] = score
112
112
  return _score(rubric.criteria, values, rubric.aggregation)
113
113
 
114
+
115
+
116
+
@@ -0,0 +1,315 @@
1
+ """Helpers for trace correlation ID extraction and inclusion in task apps.
2
+
3
+ This module provides utilities for task apps to:
4
+ 1. Extract trace_correlation_id from rollout requests
5
+ 2. Include trace_correlation_id in rollout responses (3 required locations)
6
+
7
+ See monorepo/trace_creation_and_judgement.txt "Fatal Guards" section for requirements.
8
+ """
9
+
10
+ import logging
11
+ from typing import Any
12
+ from urllib.parse import parse_qs, urlparse
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def extract_trace_correlation_id(
18
+ policy_config: dict[str, Any],
19
+ inference_url: str | None = None,
20
+ mode: Any = None
21
+ ) -> str | None:
22
+ """
23
+ Extract trace_correlation_id from policy config or inference URL.
24
+
25
+ This is the standardized method for all task apps to extract the correlation ID
26
+ that the RL trainer generates and passes to the task app.
27
+
28
+ Args:
29
+ policy_config: Policy configuration dict from RolloutRequest.policy.config
30
+ inference_url: Inference URL (optional, used as fallback)
31
+ mode: RolloutMode or string ("rl" or "eval"). Controls warning behavior -
32
+ warnings only logged for RL mode, not EVAL mode.
33
+
34
+ Returns:
35
+ trace_correlation_id if found, None otherwise
36
+
37
+ Extraction order:
38
+ 1. policy_config["trace_correlation_id"] (preferred)
39
+ 2. policy_config["trace"] (legacy fallback)
40
+ 3. URL query param ?cid=... (fallback)
41
+ 4. URL query param ?trace_correlation_id=... (fallback)
42
+ """
43
+ # Try policy_config first (preferred method)
44
+ candidates: list[Any] = [
45
+ policy_config.get("trace_correlation_id"),
46
+ policy_config.get("trace"),
47
+ ]
48
+
49
+ logger.debug(
50
+ "extract_trace_correlation_id: policy_cfg keys=%s candidates=%s",
51
+ sorted(policy_config.keys()),
52
+ candidates,
53
+ )
54
+
55
+ for candidate in candidates:
56
+ if isinstance(candidate, str):
57
+ stripped = candidate.strip()
58
+ if stripped:
59
+ logger.info(
60
+ "extract_trace_correlation_id: extracted from policy_config=%s",
61
+ stripped
62
+ )
63
+ return stripped
64
+
65
+ # Determine if we're in EVAL mode (trace_correlation_id not required for eval)
66
+ try:
67
+ from synth_ai.task.contracts import RolloutMode
68
+ is_eval_mode = (mode == "eval" or mode == RolloutMode.EVAL or
69
+ (hasattr(mode, 'value') and mode.value == "eval"))
70
+ except ImportError:
71
+ # If RolloutMode not available, fall back to string comparison
72
+ is_eval_mode = (mode == "eval")
73
+
74
+ # Fallback: try to extract from inference_url query params
75
+ if not inference_url or not isinstance(inference_url, str):
76
+ if is_eval_mode:
77
+ logger.debug(
78
+ "extract_trace_correlation_id: no correlation ID found in policy_config "
79
+ "and no inference_url provided (EVAL mode - expected)"
80
+ )
81
+ else:
82
+ logger.warning(
83
+ "extract_trace_correlation_id: no correlation ID found in policy_config "
84
+ "and no inference_url provided"
85
+ )
86
+ return None
87
+
88
+ try:
89
+ parsed = urlparse(inference_url)
90
+ query_params = parse_qs(parsed.query or "")
91
+ # Try multiple possible query param names
92
+ for param_name in ["cid", "trace_correlation_id", "trace"]:
93
+ values = query_params.get(param_name, [])
94
+ for value in values:
95
+ if isinstance(value, str) and value.strip():
96
+ correlation_id = value.strip()
97
+ logger.info(
98
+ "extract_trace_correlation_id: extracted from URL param %s=%s",
99
+ param_name,
100
+ correlation_id,
101
+ )
102
+ return correlation_id
103
+ except Exception as e:
104
+ logger.warning(
105
+ "extract_trace_correlation_id: failed to parse inference_url=%s error=%s",
106
+ inference_url,
107
+ e,
108
+ )
109
+
110
+ if is_eval_mode:
111
+ logger.debug(
112
+ "extract_trace_correlation_id: no trace_correlation_id found in "
113
+ "policy_config or inference_url=%s (EVAL mode - expected)",
114
+ inference_url,
115
+ )
116
+ else:
117
+ logger.warning(
118
+ "extract_trace_correlation_id: no trace_correlation_id found in "
119
+ "policy_config or inference_url=%s",
120
+ inference_url,
121
+ )
122
+ return None
123
+
124
+
125
+ def validate_trace_correlation_id(
126
+ trace_correlation_id: str | None,
127
+ run_id: str,
128
+ policy_config: dict[str, Any],
129
+ fatal: bool = False
130
+ ) -> str | None:
131
+ """
132
+ Validate that trace_correlation_id was successfully extracted.
133
+
134
+ Args:
135
+ trace_correlation_id: The extracted correlation ID (or None)
136
+ run_id: Rollout run_id for logging
137
+ policy_config: Policy configuration for debugging
138
+ fatal: If True, raise ValueError on missing ID. If False, log error only.
139
+
140
+ Returns:
141
+ trace_correlation_id if present, None if missing (when fatal=False)
142
+
143
+ Raises:
144
+ ValueError: If trace_correlation_id is missing and fatal=True
145
+ """
146
+ if not trace_correlation_id:
147
+ error_msg = (
148
+ f"🚨 CRITICAL: Cannot extract trace_correlation_id!\n"
149
+ "\n"
150
+ f"Run ID: {run_id}\n"
151
+ f"Policy config keys: {sorted(policy_config.keys())}\n"
152
+ f"Inference URL: {policy_config.get('inference_url', 'NOT_SET')}\n"
153
+ "\n"
154
+ "Checked:\n"
155
+ f"1. policy_config['trace_correlation_id']: {policy_config.get('trace_correlation_id')}\n"
156
+ f"2. policy_config['trace']: {policy_config.get('trace')}\n"
157
+ f"3. inference_url query params\n"
158
+ "\n"
159
+ "Task app CANNOT proceed without trace_correlation_id.\n"
160
+ "This indicates the RL trainer is not sending it correctly.\n"
161
+ "\n"
162
+ "See monorepo/trace_creation_and_judgement.txt 'Fatal Guards' section.\n"
163
+ )
164
+
165
+ if fatal:
166
+ raise ValueError(error_msg)
167
+ else:
168
+ logger.error(error_msg)
169
+
170
+ return trace_correlation_id
171
+
172
+
173
+ def include_trace_correlation_id_in_response(
174
+ response_data: dict[str, Any],
175
+ trace_correlation_id: str | None,
176
+ run_id: str
177
+ ) -> dict[str, Any]:
178
+ """
179
+ Include trace_correlation_id in all required locations of rollout response.
180
+
181
+ Required locations (per Fatal Guards section):
182
+ 1. Top-level response["trace_correlation_id"]
183
+ 2. response["pipeline_metadata"]["trace_correlation_id"]
184
+ 3. Each trajectory["trace_correlation_id"]
185
+
186
+ Args:
187
+ response_data: RolloutResponse dict (from .model_dump())
188
+ trace_correlation_id: The correlation ID to include
189
+ run_id: Rollout run_id for logging
190
+
191
+ Returns:
192
+ Modified response_data with trace_correlation_id in all required places
193
+ """
194
+ if not trace_correlation_id:
195
+ logger.error(
196
+ "include_trace_correlation_id_in_response: missing trace_correlation_id "
197
+ "for run_id=%s - cannot include in response",
198
+ run_id
199
+ )
200
+ return response_data
201
+
202
+ # 1. Add to top-level (REQUIRED)
203
+ if "trace_correlation_id" not in response_data:
204
+ response_data["trace_correlation_id"] = trace_correlation_id
205
+ logger.info(
206
+ "include_trace_correlation_id: added to top-level run_id=%s cid=%s",
207
+ run_id,
208
+ trace_correlation_id
209
+ )
210
+
211
+ # 2. Add to pipeline_metadata (REQUIRED)
212
+ pipeline_meta = response_data.get("pipeline_metadata")
213
+ if not isinstance(pipeline_meta, dict):
214
+ pipeline_meta = {}
215
+ response_data["pipeline_metadata"] = pipeline_meta
216
+
217
+ if "trace_correlation_id" not in pipeline_meta:
218
+ pipeline_meta["trace_correlation_id"] = trace_correlation_id
219
+ logger.info(
220
+ "include_trace_correlation_id: added to pipeline_metadata run_id=%s cid=%s",
221
+ run_id,
222
+ trace_correlation_id
223
+ )
224
+
225
+ # 3. Add to each trajectory (REQUIRED)
226
+ trajectories = response_data.get("trajectories", [])
227
+ if isinstance(trajectories, list):
228
+ for idx, traj in enumerate(trajectories):
229
+ if isinstance(traj, dict) and "trace_correlation_id" not in traj:
230
+ traj["trace_correlation_id"] = trace_correlation_id
231
+ logger.debug(
232
+ "include_trace_correlation_id: added to trajectory[%d] run_id=%s cid=%s",
233
+ idx,
234
+ run_id,
235
+ trace_correlation_id
236
+ )
237
+
238
+ logger.info(
239
+ "include_trace_correlation_id: completed run_id=%s cid=%s "
240
+ "added to %d locations (top-level, metadata, %d trajectories)",
241
+ run_id,
242
+ trace_correlation_id,
243
+ 2 + len(trajectories),
244
+ len(trajectories)
245
+ )
246
+
247
+ return response_data
248
+
249
+
250
+ def verify_trace_correlation_id_in_response(
251
+ response_data: dict[str, Any],
252
+ expected_correlation_id: str | None,
253
+ run_id: str
254
+ ) -> bool:
255
+ """
256
+ Verify that trace_correlation_id is present in all required locations.
257
+
258
+ Args:
259
+ response_data: RolloutResponse dict to verify
260
+ expected_correlation_id: The correlation ID that should be present
261
+ run_id: Rollout run_id for logging
262
+
263
+ Returns:
264
+ True if all required locations have the correlation ID, False otherwise
265
+ """
266
+ if not expected_correlation_id:
267
+ logger.error(
268
+ "verify_trace_correlation_id: no expected_correlation_id provided for run_id=%s",
269
+ run_id
270
+ )
271
+ return False
272
+
273
+ errors = []
274
+
275
+ # Check top-level
276
+ if response_data.get("trace_correlation_id") != expected_correlation_id:
277
+ errors.append(
278
+ f"Top-level missing or mismatch: "
279
+ f"expected={expected_correlation_id} actual={response_data.get('trace_correlation_id')}"
280
+ )
281
+
282
+ # Check pipeline_metadata
283
+ pipeline_meta = response_data.get("pipeline_metadata", {})
284
+ if not isinstance(pipeline_meta, dict) or pipeline_meta.get("trace_correlation_id") != expected_correlation_id:
285
+ errors.append(
286
+ f"pipeline_metadata missing or mismatch: "
287
+ f"expected={expected_correlation_id} actual={pipeline_meta.get('trace_correlation_id') if isinstance(pipeline_meta, dict) else 'NOT_A_DICT'}"
288
+ )
289
+
290
+ # Check trajectories
291
+ trajectories = response_data.get("trajectories", [])
292
+ if isinstance(trajectories, list):
293
+ for idx, traj in enumerate(trajectories):
294
+ if isinstance(traj, dict) and traj.get("trace_correlation_id") != expected_correlation_id:
295
+ errors.append(
296
+ f"trajectory[{idx}] missing or mismatch: "
297
+ f"expected={expected_correlation_id} actual={traj.get('trace_correlation_id')}"
298
+ )
299
+
300
+ if errors:
301
+ logger.error(
302
+ "verify_trace_correlation_id: FAILED run_id=%s\n%s",
303
+ run_id,
304
+ "\n".join(errors)
305
+ )
306
+ return False
307
+
308
+ logger.info(
309
+ "verify_trace_correlation_id: PASSED run_id=%s cid=%s",
310
+ run_id,
311
+ expected_correlation_id
312
+ )
313
+ return True
314
+
315
+
@@ -4,6 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  import re
6
6
  from typing import Any
7
+ from urllib.parse import urlparse, urlunparse
7
8
 
8
9
  import click
9
10
  import httpx
@@ -11,6 +12,149 @@ import httpx
11
12
  from synth_ai.task.contracts import TaskAppEndpoints # type: ignore[attr-defined]
12
13
 
13
14
 
15
def validate_rollout_response_for_rl(response_data: dict[str, Any], *, warn_only: bool = False) -> list[str]:
    """Validate that a task app rollout response has required fields for RL training.

    The backend RL trainer requires:
    1. pipeline_metadata["inference_url"] at top level (with ?cid= for trace correlation)
    2. Each step's info.meta["inference_url"] must be present (nested structure!)

    Args:
        response_data: The rollout response dict from task app
        warn_only: If True, return warnings instead of raising exceptions

    Returns:
        List of validation warnings/errors

    Raises:
        ValueError: If critical fields are missing (unless warn_only=True)
    """
    issues: list[str] = []

    # Top-level pipeline_metadata and its inference_url.
    pipeline_metadata = response_data.get("pipeline_metadata")
    if not isinstance(pipeline_metadata, dict):
        issues.append("Missing or invalid 'pipeline_metadata' (required for RL training)")
    else:
        inference_url = pipeline_metadata.get("inference_url")
        if not inference_url:
            issues.append(
                "pipeline_metadata['inference_url'] is missing. "
                "RL trainer requires this field to extract traces."
            )
        elif not isinstance(inference_url, str):
            issues.append(
                f"pipeline_metadata['inference_url'] must be a string, got: {type(inference_url).__name__}"
            )
        elif "?cid=" not in inference_url:
            issues.append(
                f"pipeline_metadata['inference_url'] should contain '?cid=' for trace correlation. "
                f"Got: {inference_url[:80]}..."
            )

    # Per-step inference URLs inside each trajectory.
    trajectories = response_data.get("trajectories", [])
    if not trajectories:
        issues.append("No trajectories found in response")

    for traj_idx, trajectory in enumerate(trajectories):
        if not isinstance(trajectory, dict):
            continue
        for step_idx, step in enumerate(trajectory.get("steps", [])):
            if not isinstance(step, dict):
                continue

            step_info = step.get("info", {})
            if not isinstance(step_info, dict):
                issues.append(
                    f"trajectory[{traj_idx}].steps[{step_idx}].info is not a dict"
                )
                continue

            # Backend expects the URL nested under info.meta, not info itself.
            step_meta = step_info.get("meta", {})
            if not isinstance(step_meta, dict):
                issues.append(
                    f"trajectory[{traj_idx}].steps[{step_idx}].info.meta is missing or not a dict. "
                    f"RL trainer expects nested structure: info.meta.inference_url"
                )
                continue

            step_inference_url = step_meta.get("inference_url")
            if not step_inference_url:
                issues.append(
                    f"trajectory[{traj_idx}].steps[{step_idx}].info.meta['inference_url'] is missing. "
                    f"RL trainer needs this for trace extraction (nested structure required!)"
                )
            elif not isinstance(step_inference_url, str):
                issues.append(
                    f"trajectory[{traj_idx}].steps[{step_idx}].info.meta['inference_url'] must be a string, "
                    f"got: {type(step_inference_url).__name__}"
                )

    if issues and not warn_only:
        raise ValueError(
            "Task app response validation failed for RL training:\n"
            + "\n".join(f" - {issue}" for issue in issues)
        )
    return issues
104
+
105
+
106
+ def normalize_inference_url(url: str | None, *, default: str = "https://api.openai.com/v1/chat/completions") -> str:
107
+ """Normalize an inference URL to include the /v1/chat/completions path.
108
+
109
+ This utility ensures inference URLs have the correct path structure for OpenAI-compatible
110
+ chat completions endpoints, while preserving query parameters (e.g., ?cid=trace_123)
111
+ that may be added for tracing.
112
+
113
+ Args:
114
+ url: The inference URL to normalize (may be None or incomplete)
115
+ default: Default URL to use if url is None/empty
116
+
117
+ Returns:
118
+ Normalized URL with proper path and preserved query parameters
119
+
120
+ Examples:
121
+ >>> normalize_inference_url("https://api.groq.com")
122
+ 'https://api.groq.com/v1/chat/completions'
123
+
124
+ >>> normalize_inference_url("https://modal.host?cid=trace_123")
125
+ 'https://modal.host/v1/chat/completions?cid=trace_123'
126
+
127
+ >>> normalize_inference_url("https://api.openai.com/v1")
128
+ 'https://api.openai.com/v1/chat/completions'
129
+
130
+ >>> normalize_inference_url("https://api.groq.com/openai/v1/chat/completions")
131
+ 'https://api.groq.com/openai/v1/chat/completions'
132
+ """
133
+ candidate = (url or default).strip()
134
+ if not candidate:
135
+ candidate = default
136
+
137
+ # Parse the URL to separate path and query components
138
+ parsed = urlparse(candidate)
139
+
140
+ # Check if path already ends with a completions endpoint
141
+ path = parsed.path.rstrip('/')
142
+ if path.endswith("/v1/chat/completions") or path.endswith("/chat/completions"):
143
+ return candidate
144
+
145
+ # Determine what to append based on existing path
146
+ if path.endswith("/v1"):
147
+ new_path = f"{path}/chat/completions"
148
+ elif path.endswith("/chat"):
149
+ new_path = f"{path}/completions"
150
+ else:
151
+ # Default: append full path
152
+ new_path = f"{path}/v1/chat/completions" if path else "/v1/chat/completions"
153
+
154
+ # Reconstruct URL with new path and original query/fragment
155
+ return urlunparse(parsed._replace(path=new_path))
156
+
157
+
14
158
  def validate_task_app_url(url: str | None) -> str:
15
159
  """Validate and normalize a task app URL.
16
160
 
@@ -37,7 +37,7 @@ Concepts:
37
37
  from __future__ import annotations
38
38
 
39
39
  from dataclasses import asdict, dataclass, field
40
- from datetime import UTC, datetime
40
+ from datetime import datetime, timezone
41
41
  from typing import Any
42
42
 
43
43
  from .lm_call_record_abstractions import LLMCallRecord
@@ -249,7 +249,7 @@ class SessionTimeStep:
249
249
 
250
250
  step_id: str = ""
251
251
  step_index: int = 0
252
- timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
252
+ timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
253
253
  turn_number: int | None = None
254
254
  events: list[BaseEvent] = field(default_factory=list)
255
255
  markov_blanket_messages: list[SessionEventMarkovBlanketMessage] = field(default_factory=list)
@@ -283,7 +283,7 @@ class SessionTrace:
283
283
  """
284
284
 
285
285
  session_id: str = ""
286
- created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
286
+ created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
287
287
  session_time_steps: list[SessionTimeStep] = field(default_factory=list)
288
288
  event_history: list[BaseEvent] = field(default_factory=list)
289
289
  markov_blanket_message_history: list[SessionEventMarkovBlanketMessage] = field(
@@ -8,7 +8,7 @@ from __future__ import annotations
8
8
 
9
9
  import uuid
10
10
  from dataclasses import dataclass, field
11
- from datetime import UTC, datetime
11
+ from datetime import datetime, timezone
12
12
  from typing import Any, TypedDict, cast
13
13
 
14
14
  from .lm_call_record_abstractions import (
@@ -180,8 +180,8 @@ def create_llm_call_record_from_response(
180
180
  api_type=api_type,
181
181
  provider=provider,
182
182
  model_name=model_name,
183
- started_at=started_at or datetime.now(UTC),
184
- completed_at=completed_at or datetime.now(UTC),
183
+ started_at=started_at or datetime.now(timezone.utc),
184
+ completed_at=completed_at or datetime.now(timezone.utc),
185
185
  latency_ms=latency_ms,
186
186
  request_params=params,
187
187
  input_messages=input_messages,
@@ -376,8 +376,8 @@ def create_llm_call_record_from_streaming(
376
376
  api_type="responses", # Streaming typically from Responses API
377
377
  provider=provider,
378
378
  model_name=model_name,
379
- started_at=started_at or datetime.now(UTC),
380
- completed_at=completed_at or datetime.now(UTC),
379
+ started_at=started_at or datetime.now(timezone.utc),
380
+ completed_at=completed_at or datetime.now(timezone.utc),
381
381
  latency_ms=latency_ms,
382
382
  request_params=params,
383
383
  input_messages=input_messages,