synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
This module provides utilities for task apps to:
|
|
4
4
|
1. Extract trace_correlation_id from rollout requests
|
|
5
|
-
2. Include trace_correlation_id in rollout responses (
|
|
5
|
+
2. Include trace_correlation_id in rollout responses (top-level, metadata, trace)
|
|
6
6
|
|
|
7
7
|
See monorepo/trace_creation_and_judgement.txt "Fatal Guards" section for requirements.
|
|
8
8
|
"""
|
|
@@ -101,6 +101,24 @@ def extract_trace_correlation_id(
|
|
|
101
101
|
|
|
102
102
|
try:
|
|
103
103
|
parsed = urlparse(inference_url)
|
|
104
|
+
|
|
105
|
+
# 1. Try path-based extraction first (OpenAI SDK compatible format):
|
|
106
|
+
# /v1/{trial_id}/{correlation_id}/chat/completions
|
|
107
|
+
path_segments = [s for s in parsed.path.split("/") if s]
|
|
108
|
+
if len(path_segments) >= 2:
|
|
109
|
+
# Check if path ends with chat/completions
|
|
110
|
+
if path_segments[-2:] == ["chat", "completions"] and len(path_segments) >= 3:
|
|
111
|
+
# correlation_id is the segment before chat/completions
|
|
112
|
+
potential_cid = path_segments[-3]
|
|
113
|
+
# Verify it looks like a correlation ID (starts with trace_ or cid_)
|
|
114
|
+
if potential_cid.startswith("trace_") or potential_cid.startswith("cid_"):
|
|
115
|
+
logger.info(
|
|
116
|
+
"extract_trace_correlation_id: extracted from URL path=%s",
|
|
117
|
+
potential_cid,
|
|
118
|
+
)
|
|
119
|
+
return potential_cid.strip()
|
|
120
|
+
|
|
121
|
+
# 2. Fall back to query param extraction (legacy format)
|
|
104
122
|
query_params = parse_qs(parsed.query or "")
|
|
105
123
|
# Try multiple possible query param names
|
|
106
124
|
for param_name in ["cid", "trace_correlation_id", "trace"]:
|
|
@@ -193,11 +211,11 @@ def include_trace_correlation_id_in_response(
|
|
|
193
211
|
) -> dict[str, Any]:
|
|
194
212
|
"""
|
|
195
213
|
Include trace_correlation_id in all required locations of rollout response.
|
|
196
|
-
|
|
197
|
-
Required locations (
|
|
214
|
+
|
|
215
|
+
Required locations (trace-only):
|
|
198
216
|
1. Top-level response["trace_correlation_id"]
|
|
199
217
|
2. response["pipeline_metadata"]["trace_correlation_id"]
|
|
200
|
-
3.
|
|
218
|
+
3. response["trace"]["metadata"]["trace_correlation_id"] (and session_trace metadata if present)
|
|
201
219
|
|
|
202
220
|
Args:
|
|
203
221
|
response_data: RolloutResponse dict (from .model_dump())
|
|
@@ -238,32 +256,42 @@ def include_trace_correlation_id_in_response(
|
|
|
238
256
|
trace_correlation_id
|
|
239
257
|
)
|
|
240
258
|
|
|
241
|
-
# 3. Add to
|
|
242
|
-
|
|
243
|
-
if isinstance(
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
259
|
+
# 3. Add to trace metadata (REQUIRED)
|
|
260
|
+
trace_block = response_data.get("trace")
|
|
261
|
+
if isinstance(trace_block, dict):
|
|
262
|
+
trace_meta = trace_block.get("metadata")
|
|
263
|
+
if not isinstance(trace_meta, dict):
|
|
264
|
+
trace_meta = {}
|
|
265
|
+
trace_block["metadata"] = trace_meta
|
|
266
|
+
if "trace_correlation_id" not in trace_meta:
|
|
267
|
+
trace_meta["trace_correlation_id"] = trace_correlation_id
|
|
268
|
+
corr_ids = trace_meta.get("correlation_ids")
|
|
269
|
+
if isinstance(corr_ids, dict):
|
|
270
|
+
corr_map = dict(corr_ids)
|
|
271
|
+
else:
|
|
272
|
+
corr_map = {}
|
|
273
|
+
corr_map.setdefault("trace_correlation_id", trace_correlation_id)
|
|
274
|
+
trace_meta["correlation_ids"] = corr_map
|
|
275
|
+
|
|
276
|
+
session_trace = trace_block.get("session_trace")
|
|
277
|
+
if isinstance(session_trace, dict):
|
|
278
|
+
session_meta = session_trace.get("metadata")
|
|
279
|
+
if not isinstance(session_meta, dict):
|
|
280
|
+
session_meta = {}
|
|
281
|
+
session_trace["metadata"] = session_meta
|
|
282
|
+
session_meta.setdefault("trace_correlation_id", trace_correlation_id)
|
|
283
|
+
|
|
254
284
|
logger.info(
|
|
255
285
|
"include_trace_correlation_id: completed run_id=%s cid=%s "
|
|
256
|
-
"added to
|
|
286
|
+
"added to top-level, metadata, and trace",
|
|
257
287
|
run_id,
|
|
258
288
|
trace_correlation_id,
|
|
259
|
-
2 + len(trajectories),
|
|
260
|
-
len(trajectories)
|
|
261
289
|
)
|
|
262
290
|
|
|
263
291
|
return response_data
|
|
264
292
|
|
|
265
293
|
|
|
266
|
-
def
|
|
294
|
+
def build_trace_payload(
|
|
267
295
|
messages: list[dict[str, Any]],
|
|
268
296
|
response: dict[str, Any] | None = None,
|
|
269
297
|
*,
|
|
@@ -272,11 +300,7 @@ def build_trajectory_trace(
|
|
|
272
300
|
metadata: dict[str, Any] | None = None,
|
|
273
301
|
) -> dict[str, Any]:
|
|
274
302
|
"""
|
|
275
|
-
Build a
|
|
276
|
-
|
|
277
|
-
This creates the trace structure required by monorepo's trace_validation.py:
|
|
278
|
-
- trajectory.trace.event_history must be non-empty
|
|
279
|
-
- event_history contains LM call records for input/output extraction
|
|
303
|
+
Build a v3 trace payload with event_history for trace-only responses.
|
|
280
304
|
|
|
281
305
|
Args:
|
|
282
306
|
messages: The messages sent to the LLM (input)
|
|
@@ -286,160 +310,178 @@ def build_trajectory_trace(
|
|
|
286
310
|
metadata: Optional additional metadata
|
|
287
311
|
|
|
288
312
|
Returns:
|
|
289
|
-
A trace dict with event_history suitable for
|
|
290
|
-
|
|
291
|
-
Example:
|
|
292
|
-
trace = build_trajectory_trace(
|
|
293
|
-
messages=[{"role": "user", "content": "Hello"}],
|
|
294
|
-
response={"choices": [{"message": {"content": "Hi!"}}]},
|
|
295
|
-
correlation_id="trace_abc123",
|
|
296
|
-
)
|
|
297
|
-
trajectory = RolloutTrajectory(..., trace=trace)
|
|
313
|
+
A trace dict with event_history suitable for RolloutResponse.trace
|
|
298
314
|
"""
|
|
299
315
|
import uuid
|
|
300
316
|
from datetime import datetime
|
|
301
317
|
|
|
302
|
-
# Build event_history with LM call record
|
|
303
318
|
event_history: list[dict[str, Any]] = []
|
|
304
319
|
|
|
305
|
-
|
|
306
|
-
|
|
320
|
+
llm_response: dict[str, Any] = {}
|
|
321
|
+
if isinstance(response, dict):
|
|
322
|
+
if "message" in response:
|
|
323
|
+
llm_response = dict(response)
|
|
324
|
+
elif "choices" in response and isinstance(response.get("choices"), list) and response["choices"]:
|
|
325
|
+
first_choice = response["choices"][0] if isinstance(response["choices"][0], dict) else {}
|
|
326
|
+
llm_response = {
|
|
327
|
+
"message": first_choice.get("message") if isinstance(first_choice, dict) else {},
|
|
328
|
+
"usage": response.get("usage", {}),
|
|
329
|
+
"finish_reason": first_choice.get("finish_reason") if isinstance(first_choice, dict) else None,
|
|
330
|
+
}
|
|
331
|
+
else:
|
|
332
|
+
llm_response = dict(response)
|
|
333
|
+
|
|
334
|
+
llm_event: dict[str, Any] = {
|
|
335
|
+
"type": "lm_call",
|
|
307
336
|
"event_type": "lm_call",
|
|
308
337
|
"timestamp": datetime.now(UTC).isoformat(),
|
|
309
|
-
"
|
|
310
|
-
|
|
311
|
-
"response": response or {},
|
|
312
|
-
},
|
|
338
|
+
"llm_request": {"messages": messages},
|
|
339
|
+
"llm_response": llm_response,
|
|
313
340
|
}
|
|
314
341
|
|
|
315
342
|
# Add correlation ID if provided
|
|
316
343
|
if correlation_id:
|
|
317
|
-
|
|
344
|
+
llm_event["correlation_id"] = correlation_id
|
|
318
345
|
|
|
319
|
-
event_history.append(
|
|
346
|
+
event_history.append(llm_event)
|
|
347
|
+
|
|
348
|
+
trace_metadata: dict[str, Any] = dict(metadata or {})
|
|
349
|
+
trace_metadata.setdefault("session_id", session_id or str(uuid.uuid4()))
|
|
350
|
+
if correlation_id:
|
|
351
|
+
trace_metadata.setdefault("trace_correlation_id", correlation_id)
|
|
352
|
+
corr_ids = trace_metadata.get("correlation_ids")
|
|
353
|
+
if isinstance(corr_ids, dict):
|
|
354
|
+
corr_map = dict(corr_ids)
|
|
355
|
+
else:
|
|
356
|
+
corr_map = {}
|
|
357
|
+
corr_map.setdefault("trace_correlation_id", correlation_id)
|
|
358
|
+
trace_metadata["correlation_ids"] = corr_map
|
|
320
359
|
|
|
321
360
|
trace: dict[str, Any] = {
|
|
322
|
-
"
|
|
361
|
+
"schema_version": "3.0",
|
|
323
362
|
"event_history": event_history,
|
|
324
|
-
"
|
|
363
|
+
"markov_blanket_message_history": [],
|
|
364
|
+
"metadata": trace_metadata,
|
|
325
365
|
}
|
|
326
366
|
|
|
327
|
-
if correlation_id:
|
|
328
|
-
trace["correlation_id"] = correlation_id
|
|
329
|
-
|
|
330
|
-
if metadata:
|
|
331
|
-
trace["metadata"] = metadata
|
|
332
|
-
|
|
333
367
|
logger.debug(
|
|
334
|
-
"
|
|
368
|
+
"build_trace_payload: created trace with %d events, session_id=%s, cid=%s",
|
|
335
369
|
len(event_history),
|
|
336
|
-
|
|
370
|
+
trace_metadata.get("session_id"),
|
|
337
371
|
correlation_id,
|
|
338
372
|
)
|
|
339
373
|
|
|
340
374
|
return trace
|
|
341
375
|
|
|
342
376
|
|
|
343
|
-
def
|
|
377
|
+
def build_trajectory_trace(
|
|
378
|
+
messages: list[dict[str, Any]],
|
|
379
|
+
response: dict[str, Any] | None = None,
|
|
380
|
+
*,
|
|
381
|
+
correlation_id: str | None = None,
|
|
382
|
+
session_id: str | None = None,
|
|
383
|
+
metadata: dict[str, Any] | None = None,
|
|
384
|
+
) -> dict[str, Any]:
|
|
385
|
+
"""Backward-compatible alias for build_trace_payload."""
|
|
386
|
+
|
|
387
|
+
return build_trace_payload(
|
|
388
|
+
messages=messages,
|
|
389
|
+
response=response,
|
|
390
|
+
correlation_id=correlation_id,
|
|
391
|
+
session_id=session_id,
|
|
392
|
+
metadata=metadata,
|
|
393
|
+
)
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def include_event_history_in_response(
|
|
344
397
|
response_data: dict[str, Any],
|
|
345
|
-
|
|
346
|
-
|
|
398
|
+
messages: list[dict[str, Any]] | None = None,
|
|
399
|
+
response: dict[str, Any] | None = None,
|
|
347
400
|
*,
|
|
348
401
|
run_id: str,
|
|
349
402
|
correlation_id: str | None = None,
|
|
350
403
|
) -> dict[str, Any]:
|
|
351
404
|
"""
|
|
352
|
-
Ensure
|
|
353
|
-
|
|
354
|
-
This satisfies monorepo's trace_validation.py requirement:
|
|
355
|
-
- validate_response_has_hydrated_trace() checks for event_history
|
|
405
|
+
Ensure response.trace includes a v3 event_history payload.
|
|
356
406
|
|
|
357
407
|
Args:
|
|
358
408
|
response_data: RolloutResponse dict (from .model_dump())
|
|
359
|
-
|
|
360
|
-
|
|
409
|
+
messages: Messages for the LLM call (for building event_history)
|
|
410
|
+
response: LLM response payload
|
|
361
411
|
run_id: Rollout run_id for logging
|
|
362
412
|
correlation_id: Trace correlation ID
|
|
363
413
|
|
|
364
414
|
Returns:
|
|
365
|
-
Modified response_data with event_history in
|
|
415
|
+
Modified response_data with event_history in response.trace
|
|
366
416
|
"""
|
|
367
|
-
|
|
368
|
-
if not isinstance(
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
run_id,
|
|
372
|
-
)
|
|
373
|
-
return response_data
|
|
417
|
+
trace_block = response_data.get("trace")
|
|
418
|
+
if not isinstance(trace_block, dict):
|
|
419
|
+
trace_block = {}
|
|
420
|
+
response_data["trace"] = trace_block
|
|
374
421
|
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
422
|
+
event_history = trace_block.get("event_history")
|
|
423
|
+
session_trace = trace_block.get("session_trace")
|
|
424
|
+
if not event_history and isinstance(session_trace, dict):
|
|
425
|
+
event_history = session_trace.get("event_history")
|
|
378
426
|
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
if not isinstance(trace, dict):
|
|
382
|
-
trace = {}
|
|
383
|
-
traj["trace"] = trace
|
|
427
|
+
if isinstance(event_history, list) and event_history:
|
|
428
|
+
return response_data
|
|
384
429
|
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
idx,
|
|
392
|
-
len(event_history),
|
|
393
|
-
run_id,
|
|
394
|
-
)
|
|
395
|
-
continue
|
|
430
|
+
new_trace = build_trace_payload(
|
|
431
|
+
messages=messages or [],
|
|
432
|
+
response=response,
|
|
433
|
+
correlation_id=correlation_id,
|
|
434
|
+
metadata={"run_id": run_id},
|
|
435
|
+
)
|
|
396
436
|
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
if responses_by_trajectory and idx < len(responses_by_trajectory)
|
|
406
|
-
else None
|
|
407
|
-
)
|
|
437
|
+
# Merge new trace payload into the existing trace block.
|
|
438
|
+
trace_meta = trace_block.get("metadata")
|
|
439
|
+
if isinstance(trace_meta, dict):
|
|
440
|
+
merged_meta = dict(new_trace.get("metadata", {}))
|
|
441
|
+
merged_meta.update(trace_meta)
|
|
442
|
+
trace_block["metadata"] = merged_meta
|
|
443
|
+
else:
|
|
444
|
+
trace_block["metadata"] = new_trace.get("metadata", {})
|
|
408
445
|
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
if isinstance(obs, dict):
|
|
416
|
-
step_messages = obs.get("messages")
|
|
417
|
-
if isinstance(step_messages, list):
|
|
418
|
-
messages = step_messages
|
|
419
|
-
break
|
|
420
|
-
|
|
421
|
-
# Build the trace with event_history
|
|
422
|
-
new_trace = build_trajectory_trace(
|
|
423
|
-
messages=messages,
|
|
424
|
-
response=response,
|
|
425
|
-
correlation_id=correlation_id or traj.get("trace_correlation_id"),
|
|
426
|
-
metadata={"run_id": run_id, "trajectory_index": idx},
|
|
427
|
-
)
|
|
446
|
+
trace_block.setdefault("schema_version", new_trace.get("schema_version"))
|
|
447
|
+
trace_block["event_history"] = new_trace.get("event_history", [])
|
|
448
|
+
trace_block.setdefault(
|
|
449
|
+
"markov_blanket_message_history",
|
|
450
|
+
new_trace.get("markov_blanket_message_history", []),
|
|
451
|
+
)
|
|
428
452
|
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
logger.info(
|
|
433
|
-
"include_event_history_in_trajectories: added event_history to "
|
|
434
|
-
"trajectory[%d] run_id=%s events=%d",
|
|
435
|
-
idx,
|
|
436
|
-
run_id,
|
|
437
|
-
len(trace.get("event_history", [])),
|
|
438
|
-
)
|
|
453
|
+
if isinstance(session_trace, dict) and "event_history" not in session_trace:
|
|
454
|
+
session_trace["event_history"] = trace_block["event_history"]
|
|
439
455
|
|
|
456
|
+
logger.info(
|
|
457
|
+
"include_event_history_in_response: added event_history run_id=%s events=%d",
|
|
458
|
+
run_id,
|
|
459
|
+
len(trace_block.get("event_history", [])),
|
|
460
|
+
)
|
|
440
461
|
return response_data
|
|
441
462
|
|
|
442
463
|
|
|
464
|
+
def include_event_history_in_trajectories(
|
|
465
|
+
response_data: dict[str, Any],
|
|
466
|
+
messages_by_trajectory: list[list[dict[str, Any]]] | None = None,
|
|
467
|
+
responses_by_trajectory: list[dict[str, Any]] | None = None,
|
|
468
|
+
*,
|
|
469
|
+
run_id: str,
|
|
470
|
+
correlation_id: str | None = None,
|
|
471
|
+
) -> dict[str, Any]:
|
|
472
|
+
"""Backward-compatible alias for include_event_history_in_response."""
|
|
473
|
+
|
|
474
|
+
messages = messages_by_trajectory[0] if messages_by_trajectory else None
|
|
475
|
+
response = responses_by_trajectory[0] if responses_by_trajectory else None
|
|
476
|
+
return include_event_history_in_response(
|
|
477
|
+
response_data,
|
|
478
|
+
messages=messages,
|
|
479
|
+
response=response,
|
|
480
|
+
run_id=run_id,
|
|
481
|
+
correlation_id=correlation_id,
|
|
482
|
+
)
|
|
483
|
+
|
|
484
|
+
|
|
443
485
|
def verify_trace_correlation_id_in_response(
|
|
444
486
|
response_data: dict[str, Any],
|
|
445
487
|
expected_correlation_id: str | None,
|
|
@@ -480,15 +522,24 @@ def verify_trace_correlation_id_in_response(
|
|
|
480
522
|
f"expected={expected_correlation_id} actual={pipeline_meta.get('trace_correlation_id') if isinstance(pipeline_meta, dict) else 'NOT_A_DICT'}"
|
|
481
523
|
)
|
|
482
524
|
|
|
483
|
-
# Check
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
525
|
+
# Check trace metadata
|
|
526
|
+
trace_block = response_data.get("trace")
|
|
527
|
+
trace_meta_id = None
|
|
528
|
+
if isinstance(trace_block, dict):
|
|
529
|
+
trace_meta = trace_block.get("metadata")
|
|
530
|
+
if isinstance(trace_meta, dict):
|
|
531
|
+
trace_meta_id = trace_meta.get("trace_correlation_id")
|
|
532
|
+
if trace_meta_id != expected_correlation_id:
|
|
533
|
+
session_trace = trace_block.get("session_trace")
|
|
534
|
+
if isinstance(session_trace, dict):
|
|
535
|
+
session_meta = session_trace.get("metadata")
|
|
536
|
+
if isinstance(session_meta, dict):
|
|
537
|
+
trace_meta_id = session_meta.get("trace_correlation_id")
|
|
538
|
+
if trace_meta_id != expected_correlation_id:
|
|
539
|
+
errors.append(
|
|
540
|
+
"trace.metadata missing or mismatch: "
|
|
541
|
+
f"expected={expected_correlation_id} actual={trace_meta_id}"
|
|
542
|
+
)
|
|
492
543
|
|
|
493
544
|
if errors:
|
|
494
545
|
logger.error(
|
synth_ai/sdk/task/validators.py
CHANGED
|
@@ -16,8 +16,8 @@ def validate_rollout_response_for_rl(response_data: dict[str, Any], *, warn_only
|
|
|
16
16
|
"""Validate that a task app rollout response has required fields for RL training.
|
|
17
17
|
|
|
18
18
|
The backend RL trainer requires:
|
|
19
|
-
1.
|
|
20
|
-
2.
|
|
19
|
+
1. A v3 trace with event_history (preferred), OR
|
|
20
|
+
2. pipeline_metadata["inference_url"] with ?cid= for trace hydration fallback
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
23
|
response_data: The rollout response dict from task app
|
|
@@ -31,16 +31,43 @@ def validate_rollout_response_for_rl(response_data: dict[str, Any], *, warn_only
|
|
|
31
31
|
"""
|
|
32
32
|
issues = []
|
|
33
33
|
|
|
34
|
-
|
|
34
|
+
trace_block = response_data.get("trace")
|
|
35
|
+
event_history = None
|
|
36
|
+
if isinstance(trace_block, dict):
|
|
37
|
+
event_history = trace_block.get("event_history")
|
|
38
|
+
if not event_history and isinstance(trace_block.get("session_trace"), dict):
|
|
39
|
+
event_history = trace_block["session_trace"].get("event_history")
|
|
40
|
+
|
|
41
|
+
has_event_history = isinstance(event_history, list) and len(event_history) > 0
|
|
42
|
+
|
|
43
|
+
trace_correlation_id = response_data.get("trace_correlation_id")
|
|
44
|
+
if not trace_correlation_id and isinstance(trace_block, dict):
|
|
45
|
+
trace_meta = trace_block.get("metadata")
|
|
46
|
+
if isinstance(trace_meta, dict):
|
|
47
|
+
trace_correlation_id = trace_meta.get("trace_correlation_id")
|
|
48
|
+
|
|
49
|
+
if not trace_correlation_id:
|
|
50
|
+
issues.append(
|
|
51
|
+
"Missing trace_correlation_id (top-level or trace.metadata). "
|
|
52
|
+
"RL trainer requires this to link traces."
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
if not has_event_history:
|
|
56
|
+
issues.append(
|
|
57
|
+
"trace.event_history is missing or empty. "
|
|
58
|
+
"Return a v3 trace or provide inference_url for hydration."
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
# Check pipeline_metadata inference_url only when trace is missing/empty
|
|
35
62
|
pipeline_metadata = response_data.get("pipeline_metadata")
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
else:
|
|
63
|
+
inference_url = None
|
|
64
|
+
if isinstance(pipeline_metadata, dict):
|
|
39
65
|
inference_url = pipeline_metadata.get("inference_url")
|
|
66
|
+
if not has_event_history:
|
|
40
67
|
if not inference_url:
|
|
41
68
|
issues.append(
|
|
42
69
|
"pipeline_metadata['inference_url'] is missing. "
|
|
43
|
-
"RL trainer
|
|
70
|
+
"RL trainer needs this to hydrate traces when event_history is absent."
|
|
44
71
|
)
|
|
45
72
|
elif not isinstance(inference_url, str):
|
|
46
73
|
issues.append(
|
|
@@ -52,48 +79,6 @@ def validate_rollout_response_for_rl(response_data: dict[str, Any], *, warn_only
|
|
|
52
79
|
f"Got: {inference_url[:80]}..."
|
|
53
80
|
)
|
|
54
81
|
|
|
55
|
-
# Check trajectories and steps
|
|
56
|
-
trajectories = response_data.get("trajectories", [])
|
|
57
|
-
if not trajectories:
|
|
58
|
-
issues.append("No trajectories found in response")
|
|
59
|
-
|
|
60
|
-
for traj_idx, trajectory in enumerate(trajectories):
|
|
61
|
-
if not isinstance(trajectory, dict):
|
|
62
|
-
continue
|
|
63
|
-
|
|
64
|
-
steps = trajectory.get("steps", [])
|
|
65
|
-
for step_idx, step in enumerate(steps):
|
|
66
|
-
if not isinstance(step, dict):
|
|
67
|
-
continue
|
|
68
|
-
|
|
69
|
-
step_info = step.get("info", {})
|
|
70
|
-
if not isinstance(step_info, dict):
|
|
71
|
-
issues.append(
|
|
72
|
-
f"trajectory[{traj_idx}].steps[{step_idx}].info is not a dict"
|
|
73
|
-
)
|
|
74
|
-
continue
|
|
75
|
-
|
|
76
|
-
# Check for nested meta.inference_url (backend expects this structure!)
|
|
77
|
-
step_meta = step_info.get("meta", {})
|
|
78
|
-
if not isinstance(step_meta, dict):
|
|
79
|
-
issues.append(
|
|
80
|
-
f"trajectory[{traj_idx}].steps[{step_idx}].info.meta is missing or not a dict. "
|
|
81
|
-
f"RL trainer expects nested structure: info.meta.inference_url"
|
|
82
|
-
)
|
|
83
|
-
continue
|
|
84
|
-
|
|
85
|
-
step_inference_url = step_meta.get("inference_url")
|
|
86
|
-
if not step_inference_url:
|
|
87
|
-
issues.append(
|
|
88
|
-
f"trajectory[{traj_idx}].steps[{step_idx}].info.meta['inference_url'] is missing. "
|
|
89
|
-
f"RL trainer needs this for trace extraction (nested structure required!)"
|
|
90
|
-
)
|
|
91
|
-
elif not isinstance(step_inference_url, str):
|
|
92
|
-
issues.append(
|
|
93
|
-
f"trajectory[{traj_idx}].steps[{step_idx}].info.meta['inference_url'] must be a string, "
|
|
94
|
-
f"got: {type(step_inference_url).__name__}"
|
|
95
|
-
)
|
|
96
|
-
|
|
97
82
|
if issues and not warn_only:
|
|
98
83
|
error_msg = "Task app response validation failed for RL training:\n" + "\n".join(
|
|
99
84
|
f" - {issue}" for issue in issues
|
|
@@ -4,7 +4,7 @@ This module provides high-level APIs for running training jobs:
|
|
|
4
4
|
- PromptLearningJob: GEPA and MIPRO prompt optimization
|
|
5
5
|
- SFTJob: Supervised fine-tuning
|
|
6
6
|
- RLJob: Reinforcement learning (GSPO, GRPO, PPO, etc.)
|
|
7
|
-
- GraphGenJob:
|
|
7
|
+
- GraphGenJob: Graph Opt (simplified workflows API)
|
|
8
8
|
|
|
9
9
|
Example:
|
|
10
10
|
from synth_ai.sdk.training import PromptLearningJob, RLJob, GraphGenJob
|
|
@@ -30,7 +30,7 @@ Example:
|
|
|
30
30
|
from __future__ import annotations
|
|
31
31
|
|
|
32
32
|
# Pollers and utilities
|
|
33
|
-
from synth_ai.sdk.api.train.pollers import JobPoller, PollOutcome, RLJobPoller
|
|
33
|
+
from synth_ai.sdk.api.train.pollers import JobPoller, PollOutcome, RLJobPoller, EvalJobPoller
|
|
34
34
|
|
|
35
35
|
# Re-export from existing locations
|
|
36
36
|
from synth_ai.sdk.api.train.prompt_learning import (
|
|
@@ -41,7 +41,7 @@ from synth_ai.sdk.api.train.prompt_learning import (
|
|
|
41
41
|
from synth_ai.sdk.api.train.rl import RLJob, RLJobConfig
|
|
42
42
|
from synth_ai.sdk.api.train.sft import SFTJob
|
|
43
43
|
|
|
44
|
-
# GraphGen (
|
|
44
|
+
# GraphGen (Graph Opt)
|
|
45
45
|
from synth_ai.sdk.api.train.graphgen import GraphGenJob, GraphGenJobResult, GraphGenSubmitResult
|
|
46
46
|
from synth_ai.sdk.api.train.graphgen_models import (
|
|
47
47
|
GraphGenJobConfig,
|
|
@@ -49,16 +49,7 @@ from synth_ai.sdk.api.train.graphgen_models import (
|
|
|
49
49
|
GraphGenTask,
|
|
50
50
|
GraphGenGoldOutput,
|
|
51
51
|
GraphGenRubric,
|
|
52
|
-
|
|
53
|
-
load_graphgen_taskset,
|
|
54
|
-
parse_graphgen_taskset,
|
|
55
|
-
# GraphGen aliases
|
|
56
|
-
GraphGenJobConfig,
|
|
57
|
-
GraphGenTaskSet,
|
|
58
|
-
GraphGenTask,
|
|
59
|
-
GraphGenGoldOutput,
|
|
60
|
-
GraphGenRubric,
|
|
61
|
-
GraphGenJudgeConfig,
|
|
52
|
+
GraphGenVerifierConfig,
|
|
62
53
|
load_graphgen_taskset,
|
|
63
54
|
parse_graphgen_taskset,
|
|
64
55
|
)
|
|
@@ -80,7 +71,7 @@ __all__ = [
|
|
|
80
71
|
"GraphGenTask",
|
|
81
72
|
"GraphGenGoldOutput",
|
|
82
73
|
"GraphGenRubric",
|
|
83
|
-
"
|
|
74
|
+
"GraphGenVerifierConfig",
|
|
84
75
|
"load_graphgen_taskset",
|
|
85
76
|
"parse_graphgen_taskset",
|
|
86
77
|
# GraphGen (legacy aliases)
|
|
@@ -92,11 +83,11 @@ __all__ = [
|
|
|
92
83
|
"GraphGenTask",
|
|
93
84
|
"GraphGenGoldOutput",
|
|
94
85
|
"GraphGenRubric",
|
|
95
|
-
"
|
|
86
|
+
"GraphGenVerifierConfig",
|
|
96
87
|
"load_graphgen_taskset",
|
|
97
88
|
"parse_graphgen_taskset",
|
|
98
89
|
# Utils
|
|
99
90
|
"JobPoller",
|
|
100
91
|
"PollOutcome",
|
|
92
|
+
"EvalJobPoller",
|
|
101
93
|
]
|
|
102
|
-
|