synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (153)
  1. synth_ai/__init__.py +13 -13
  2. synth_ai/cli/__init__.py +6 -15
  3. synth_ai/cli/commands/eval/__init__.py +6 -15
  4. synth_ai/cli/commands/eval/config.py +338 -0
  5. synth_ai/cli/commands/eval/core.py +236 -1091
  6. synth_ai/cli/commands/eval/runner.py +704 -0
  7. synth_ai/cli/commands/eval/validation.py +44 -117
  8. synth_ai/cli/commands/filter/core.py +7 -7
  9. synth_ai/cli/commands/filter/validation.py +2 -2
  10. synth_ai/cli/commands/smoke/core.py +7 -17
  11. synth_ai/cli/commands/status/__init__.py +1 -64
  12. synth_ai/cli/commands/status/client.py +50 -151
  13. synth_ai/cli/commands/status/config.py +3 -83
  14. synth_ai/cli/commands/status/errors.py +4 -13
  15. synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
  16. synth_ai/cli/commands/status/subcommands/config.py +13 -0
  17. synth_ai/cli/commands/status/subcommands/files.py +18 -63
  18. synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
  19. synth_ai/cli/commands/status/subcommands/models.py +18 -62
  20. synth_ai/cli/commands/status/subcommands/runs.py +16 -63
  21. synth_ai/cli/commands/status/subcommands/session.py +67 -172
  22. synth_ai/cli/commands/status/subcommands/summary.py +24 -32
  23. synth_ai/cli/commands/status/subcommands/utils.py +41 -0
  24. synth_ai/cli/commands/status/utils.py +16 -107
  25. synth_ai/cli/commands/train/__init__.py +18 -20
  26. synth_ai/cli/commands/train/errors.py +3 -3
  27. synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
  28. synth_ai/cli/commands/train/validation.py +7 -7
  29. synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
  30. synth_ai/cli/commands/train/verifier_validation.py +235 -0
  31. synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
  32. synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
  33. synth_ai/cli/demo_apps/math/config.toml +0 -1
  34. synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
  35. synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
  36. synth_ai/cli/lib/apps/task_app.py +12 -13
  37. synth_ai/cli/lib/task_app_discovery.py +6 -6
  38. synth_ai/cli/lib/train_cfgs.py +10 -10
  39. synth_ai/cli/task_apps/__init__.py +11 -0
  40. synth_ai/cli/task_apps/commands.py +7 -15
  41. synth_ai/core/env.py +12 -1
  42. synth_ai/core/errors.py +1 -2
  43. synth_ai/core/integrations/cloudflare.py +209 -33
  44. synth_ai/core/tracing_v3/abstractions.py +46 -0
  45. synth_ai/data/__init__.py +3 -30
  46. synth_ai/data/enums.py +1 -20
  47. synth_ai/data/rewards.py +100 -3
  48. synth_ai/products/graph_evolve/__init__.py +1 -2
  49. synth_ai/products/graph_evolve/config.py +16 -16
  50. synth_ai/products/graph_evolve/converters/__init__.py +3 -3
  51. synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
  52. synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
  53. synth_ai/products/graph_gepa/__init__.py +23 -0
  54. synth_ai/products/graph_gepa/converters/__init__.py +19 -0
  55. synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
  56. synth_ai/sdk/__init__.py +45 -35
  57. synth_ai/sdk/api/eval/__init__.py +33 -0
  58. synth_ai/sdk/api/eval/job.py +732 -0
  59. synth_ai/sdk/api/research_agent/__init__.py +276 -66
  60. synth_ai/sdk/api/train/builders.py +181 -0
  61. synth_ai/sdk/api/train/cli.py +41 -33
  62. synth_ai/sdk/api/train/configs/__init__.py +6 -4
  63. synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
  64. synth_ai/sdk/api/train/configs/rl.py +264 -16
  65. synth_ai/sdk/api/train/configs/sft.py +165 -1
  66. synth_ai/sdk/api/train/graph_validators.py +12 -12
  67. synth_ai/sdk/api/train/graphgen.py +169 -51
  68. synth_ai/sdk/api/train/graphgen_models.py +95 -45
  69. synth_ai/sdk/api/train/local_api.py +10 -0
  70. synth_ai/sdk/api/train/pollers.py +36 -0
  71. synth_ai/sdk/api/train/prompt_learning.py +390 -60
  72. synth_ai/sdk/api/train/rl.py +41 -5
  73. synth_ai/sdk/api/train/sft.py +2 -0
  74. synth_ai/sdk/api/train/task_app.py +20 -0
  75. synth_ai/sdk/api/train/validators.py +17 -17
  76. synth_ai/sdk/graphs/completions.py +239 -33
  77. synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
  78. synth_ai/sdk/learning/__init__.py +35 -5
  79. synth_ai/sdk/learning/context_learning_client.py +531 -0
  80. synth_ai/sdk/learning/context_learning_types.py +294 -0
  81. synth_ai/sdk/learning/prompt_learning_client.py +1 -1
  82. synth_ai/sdk/learning/prompt_learning_types.py +2 -1
  83. synth_ai/sdk/learning/rl/__init__.py +0 -4
  84. synth_ai/sdk/learning/rl/contracts.py +0 -4
  85. synth_ai/sdk/localapi/__init__.py +40 -0
  86. synth_ai/sdk/localapi/apps/__init__.py +28 -0
  87. synth_ai/sdk/localapi/client.py +10 -0
  88. synth_ai/sdk/localapi/contracts.py +10 -0
  89. synth_ai/sdk/localapi/helpers.py +519 -0
  90. synth_ai/sdk/localapi/rollouts.py +93 -0
  91. synth_ai/sdk/localapi/server.py +29 -0
  92. synth_ai/sdk/localapi/template.py +49 -0
  93. synth_ai/sdk/streaming/handlers.py +6 -6
  94. synth_ai/sdk/streaming/streamer.py +10 -6
  95. synth_ai/sdk/task/__init__.py +18 -5
  96. synth_ai/sdk/task/apps/__init__.py +37 -1
  97. synth_ai/sdk/task/client.py +9 -1
  98. synth_ai/sdk/task/config.py +6 -11
  99. synth_ai/sdk/task/contracts.py +137 -95
  100. synth_ai/sdk/task/in_process.py +32 -22
  101. synth_ai/sdk/task/in_process_runner.py +9 -4
  102. synth_ai/sdk/task/rubrics/__init__.py +2 -3
  103. synth_ai/sdk/task/rubrics/loaders.py +4 -4
  104. synth_ai/sdk/task/rubrics/strict.py +3 -4
  105. synth_ai/sdk/task/server.py +76 -16
  106. synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
  107. synth_ai/sdk/task/validators.py +34 -49
  108. synth_ai/sdk/training/__init__.py +7 -16
  109. synth_ai/sdk/tunnels/__init__.py +118 -0
  110. synth_ai/sdk/tunnels/cleanup.py +83 -0
  111. synth_ai/sdk/tunnels/ports.py +120 -0
  112. synth_ai/sdk/tunnels/tunneled_api.py +363 -0
  113. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
  114. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
  115. synth_ai/cli/commands/baseline/__init__.py +0 -12
  116. synth_ai/cli/commands/baseline/core.py +0 -636
  117. synth_ai/cli/commands/baseline/list.py +0 -94
  118. synth_ai/cli/commands/eval/errors.py +0 -81
  119. synth_ai/cli/commands/status/formatters.py +0 -164
  120. synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
  121. synth_ai/cli/commands/status/subcommands/usage.py +0 -203
  122. synth_ai/cli/commands/train/judge_validation.py +0 -305
  123. synth_ai/cli/usage.py +0 -159
  124. synth_ai/data/specs.py +0 -36
  125. synth_ai/sdk/api/research_agent/cli.py +0 -428
  126. synth_ai/sdk/api/research_agent/config.py +0 -357
  127. synth_ai/sdk/api/research_agent/job.py +0 -717
  128. synth_ai/sdk/baseline/__init__.py +0 -25
  129. synth_ai/sdk/baseline/config.py +0 -209
  130. synth_ai/sdk/baseline/discovery.py +0 -216
  131. synth_ai/sdk/baseline/execution.py +0 -154
  132. synth_ai/sdk/judging/__init__.py +0 -15
  133. synth_ai/sdk/judging/base.py +0 -24
  134. synth_ai/sdk/judging/client.py +0 -191
  135. synth_ai/sdk/judging/types.py +0 -42
  136. synth_ai/sdk/research_agent/__init__.py +0 -34
  137. synth_ai/sdk/research_agent/container_builder.py +0 -328
  138. synth_ai/sdk/research_agent/container_spec.py +0 -198
  139. synth_ai/sdk/research_agent/defaults.py +0 -34
  140. synth_ai/sdk/research_agent/results_collector.py +0 -69
  141. synth_ai/sdk/specs/__init__.py +0 -46
  142. synth_ai/sdk/specs/dataclasses.py +0 -149
  143. synth_ai/sdk/specs/loader.py +0 -144
  144. synth_ai/sdk/specs/serializer.py +0 -199
  145. synth_ai/sdk/specs/validation.py +0 -250
  146. synth_ai/sdk/tracing/__init__.py +0 -39
  147. synth_ai/sdk/usage/__init__.py +0 -37
  148. synth_ai/sdk/usage/client.py +0 -171
  149. synth_ai/sdk/usage/models.py +0 -261
  150. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
  151. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
  152. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
  153. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@
2
2
 
3
3
  This module provides utilities for task apps to:
4
4
  1. Extract trace_correlation_id from rollout requests
5
- 2. Include trace_correlation_id in rollout responses (3 required locations)
5
+ 2. Include trace_correlation_id in rollout responses (top-level, metadata, trace)
6
6
 
7
7
  See monorepo/trace_creation_and_judgement.txt "Fatal Guards" section for requirements.
8
8
  """
@@ -101,6 +101,24 @@ def extract_trace_correlation_id(
101
101
 
102
102
  try:
103
103
  parsed = urlparse(inference_url)
104
+
105
+ # 1. Try path-based extraction first (OpenAI SDK compatible format):
106
+ # /v1/{trial_id}/{correlation_id}/chat/completions
107
+ path_segments = [s for s in parsed.path.split("/") if s]
108
+ if len(path_segments) >= 2:
109
+ # Check if path ends with chat/completions
110
+ if path_segments[-2:] == ["chat", "completions"] and len(path_segments) >= 3:
111
+ # correlation_id is the segment before chat/completions
112
+ potential_cid = path_segments[-3]
113
+ # Verify it looks like a correlation ID (starts with trace_ or cid_)
114
+ if potential_cid.startswith("trace_") or potential_cid.startswith("cid_"):
115
+ logger.info(
116
+ "extract_trace_correlation_id: extracted from URL path=%s",
117
+ potential_cid,
118
+ )
119
+ return potential_cid.strip()
120
+
121
+ # 2. Fall back to query param extraction (legacy format)
104
122
  query_params = parse_qs(parsed.query or "")
105
123
  # Try multiple possible query param names
106
124
  for param_name in ["cid", "trace_correlation_id", "trace"]:
@@ -193,11 +211,11 @@ def include_trace_correlation_id_in_response(
193
211
  ) -> dict[str, Any]:
194
212
  """
195
213
  Include trace_correlation_id in all required locations of rollout response.
196
-
197
- Required locations (per Fatal Guards section):
214
+
215
+ Required locations (trace-only):
198
216
  1. Top-level response["trace_correlation_id"]
199
217
  2. response["pipeline_metadata"]["trace_correlation_id"]
200
- 3. Each trajectory["trace_correlation_id"]
218
+ 3. response["trace"]["metadata"]["trace_correlation_id"] (and session_trace metadata if present)
201
219
 
202
220
  Args:
203
221
  response_data: RolloutResponse dict (from .model_dump())
@@ -238,32 +256,42 @@ def include_trace_correlation_id_in_response(
238
256
  trace_correlation_id
239
257
  )
240
258
 
241
- # 3. Add to each trajectory (REQUIRED)
242
- trajectories = response_data.get("trajectories", [])
243
- if isinstance(trajectories, list):
244
- for idx, traj in enumerate(trajectories):
245
- if isinstance(traj, dict) and "trace_correlation_id" not in traj:
246
- traj["trace_correlation_id"] = trace_correlation_id
247
- logger.debug(
248
- "include_trace_correlation_id: added to trajectory[%d] run_id=%s cid=%s",
249
- idx,
250
- run_id,
251
- trace_correlation_id
252
- )
253
-
259
+ # 3. Add to trace metadata (REQUIRED)
260
+ trace_block = response_data.get("trace")
261
+ if isinstance(trace_block, dict):
262
+ trace_meta = trace_block.get("metadata")
263
+ if not isinstance(trace_meta, dict):
264
+ trace_meta = {}
265
+ trace_block["metadata"] = trace_meta
266
+ if "trace_correlation_id" not in trace_meta:
267
+ trace_meta["trace_correlation_id"] = trace_correlation_id
268
+ corr_ids = trace_meta.get("correlation_ids")
269
+ if isinstance(corr_ids, dict):
270
+ corr_map = dict(corr_ids)
271
+ else:
272
+ corr_map = {}
273
+ corr_map.setdefault("trace_correlation_id", trace_correlation_id)
274
+ trace_meta["correlation_ids"] = corr_map
275
+
276
+ session_trace = trace_block.get("session_trace")
277
+ if isinstance(session_trace, dict):
278
+ session_meta = session_trace.get("metadata")
279
+ if not isinstance(session_meta, dict):
280
+ session_meta = {}
281
+ session_trace["metadata"] = session_meta
282
+ session_meta.setdefault("trace_correlation_id", trace_correlation_id)
283
+
254
284
  logger.info(
255
285
  "include_trace_correlation_id: completed run_id=%s cid=%s "
256
- "added to %d locations (top-level, metadata, %d trajectories)",
286
+ "added to top-level, metadata, and trace",
257
287
  run_id,
258
288
  trace_correlation_id,
259
- 2 + len(trajectories),
260
- len(trajectories)
261
289
  )
262
290
 
263
291
  return response_data
264
292
 
265
293
 
266
- def build_trajectory_trace(
294
+ def build_trace_payload(
267
295
  messages: list[dict[str, Any]],
268
296
  response: dict[str, Any] | None = None,
269
297
  *,
@@ -272,11 +300,7 @@ def build_trajectory_trace(
272
300
  metadata: dict[str, Any] | None = None,
273
301
  ) -> dict[str, Any]:
274
302
  """
275
- Build a trajectory-level trace with event_history for trace strict mode.
276
-
277
- This creates the trace structure required by monorepo's trace_validation.py:
278
- - trajectory.trace.event_history must be non-empty
279
- - event_history contains LM call records for input/output extraction
303
+ Build a v3 trace payload with event_history for trace-only responses.
280
304
 
281
305
  Args:
282
306
  messages: The messages sent to the LLM (input)
@@ -286,160 +310,178 @@ def build_trajectory_trace(
286
310
  metadata: Optional additional metadata
287
311
 
288
312
  Returns:
289
- A trace dict with event_history suitable for trajectory.trace
290
-
291
- Example:
292
- trace = build_trajectory_trace(
293
- messages=[{"role": "user", "content": "Hello"}],
294
- response={"choices": [{"message": {"content": "Hi!"}}]},
295
- correlation_id="trace_abc123",
296
- )
297
- trajectory = RolloutTrajectory(..., trace=trace)
313
+ A trace dict with event_history suitable for RolloutResponse.trace
298
314
  """
299
315
  import uuid
300
316
  from datetime import datetime
301
317
 
302
- # Build event_history with LM call record
303
318
  event_history: list[dict[str, Any]] = []
304
319
 
305
- # Create an LM call event (the primary event type for input/output extraction)
306
- lm_event: dict[str, Any] = {
320
+ llm_response: dict[str, Any] = {}
321
+ if isinstance(response, dict):
322
+ if "message" in response:
323
+ llm_response = dict(response)
324
+ elif "choices" in response and isinstance(response.get("choices"), list) and response["choices"]:
325
+ first_choice = response["choices"][0] if isinstance(response["choices"][0], dict) else {}
326
+ llm_response = {
327
+ "message": first_choice.get("message") if isinstance(first_choice, dict) else {},
328
+ "usage": response.get("usage", {}),
329
+ "finish_reason": first_choice.get("finish_reason") if isinstance(first_choice, dict) else None,
330
+ }
331
+ else:
332
+ llm_response = dict(response)
333
+
334
+ llm_event: dict[str, Any] = {
335
+ "type": "lm_call",
307
336
  "event_type": "lm_call",
308
337
  "timestamp": datetime.now(UTC).isoformat(),
309
- "call_record": {
310
- "messages": messages,
311
- "response": response or {},
312
- },
338
+ "llm_request": {"messages": messages},
339
+ "llm_response": llm_response,
313
340
  }
314
341
 
315
342
  # Add correlation ID if provided
316
343
  if correlation_id:
317
- lm_event["correlation_id"] = correlation_id
344
+ llm_event["correlation_id"] = correlation_id
318
345
 
319
- event_history.append(lm_event)
346
+ event_history.append(llm_event)
347
+
348
+ trace_metadata: dict[str, Any] = dict(metadata or {})
349
+ trace_metadata.setdefault("session_id", session_id or str(uuid.uuid4()))
350
+ if correlation_id:
351
+ trace_metadata.setdefault("trace_correlation_id", correlation_id)
352
+ corr_ids = trace_metadata.get("correlation_ids")
353
+ if isinstance(corr_ids, dict):
354
+ corr_map = dict(corr_ids)
355
+ else:
356
+ corr_map = {}
357
+ corr_map.setdefault("trace_correlation_id", correlation_id)
358
+ trace_metadata["correlation_ids"] = corr_map
320
359
 
321
360
  trace: dict[str, Any] = {
322
- "session_id": session_id or str(uuid.uuid4()),
361
+ "schema_version": "3.0",
323
362
  "event_history": event_history,
324
- "created_at": datetime.now(UTC).isoformat(),
363
+ "markov_blanket_message_history": [],
364
+ "metadata": trace_metadata,
325
365
  }
326
366
 
327
- if correlation_id:
328
- trace["correlation_id"] = correlation_id
329
-
330
- if metadata:
331
- trace["metadata"] = metadata
332
-
333
367
  logger.debug(
334
- "build_trajectory_trace: created trace with %d events, session_id=%s, cid=%s",
368
+ "build_trace_payload: created trace with %d events, session_id=%s, cid=%s",
335
369
  len(event_history),
336
- trace["session_id"],
370
+ trace_metadata.get("session_id"),
337
371
  correlation_id,
338
372
  )
339
373
 
340
374
  return trace
341
375
 
342
376
 
343
- def include_event_history_in_trajectories(
377
+ def build_trajectory_trace(
378
+ messages: list[dict[str, Any]],
379
+ response: dict[str, Any] | None = None,
380
+ *,
381
+ correlation_id: str | None = None,
382
+ session_id: str | None = None,
383
+ metadata: dict[str, Any] | None = None,
384
+ ) -> dict[str, Any]:
385
+ """Backward-compatible alias for build_trace_payload."""
386
+
387
+ return build_trace_payload(
388
+ messages=messages,
389
+ response=response,
390
+ correlation_id=correlation_id,
391
+ session_id=session_id,
392
+ metadata=metadata,
393
+ )
394
+
395
+
396
+ def include_event_history_in_response(
344
397
  response_data: dict[str, Any],
345
- messages_by_trajectory: list[list[dict[str, Any]]] | None = None,
346
- responses_by_trajectory: list[dict[str, Any]] | None = None,
398
+ messages: list[dict[str, Any]] | None = None,
399
+ response: dict[str, Any] | None = None,
347
400
  *,
348
401
  run_id: str,
349
402
  correlation_id: str | None = None,
350
403
  ) -> dict[str, Any]:
351
404
  """
352
- Ensure all trajectories have trace.event_history for trace strict mode.
353
-
354
- This satisfies monorepo's trace_validation.py requirement:
355
- - validate_response_has_hydrated_trace() checks for event_history
405
+ Ensure response.trace includes a v3 event_history payload.
356
406
 
357
407
  Args:
358
408
  response_data: RolloutResponse dict (from .model_dump())
359
- messages_by_trajectory: List of messages for each trajectory (for building event_history)
360
- responses_by_trajectory: List of LLM responses for each trajectory
409
+ messages: Messages for the LLM call (for building event_history)
410
+ response: LLM response payload
361
411
  run_id: Rollout run_id for logging
362
412
  correlation_id: Trace correlation ID
363
413
 
364
414
  Returns:
365
- Modified response_data with event_history in each trajectory.trace
415
+ Modified response_data with event_history in response.trace
366
416
  """
367
- trajectories = response_data.get("trajectories", [])
368
- if not isinstance(trajectories, list):
369
- logger.warning(
370
- "include_event_history_in_trajectories: trajectories is not a list for run_id=%s",
371
- run_id,
372
- )
373
- return response_data
417
+ trace_block = response_data.get("trace")
418
+ if not isinstance(trace_block, dict):
419
+ trace_block = {}
420
+ response_data["trace"] = trace_block
374
421
 
375
- for idx, traj in enumerate(trajectories):
376
- if not isinstance(traj, dict):
377
- continue
422
+ event_history = trace_block.get("event_history")
423
+ session_trace = trace_block.get("session_trace")
424
+ if not event_history and isinstance(session_trace, dict):
425
+ event_history = session_trace.get("event_history")
378
426
 
379
- # Get existing trace or create new one
380
- trace = traj.get("trace")
381
- if not isinstance(trace, dict):
382
- trace = {}
383
- traj["trace"] = trace
427
+ if isinstance(event_history, list) and event_history:
428
+ return response_data
384
429
 
385
- # Check if event_history already exists and is non-empty
386
- event_history = trace.get("event_history")
387
- if isinstance(event_history, list) and len(event_history) > 0:
388
- logger.debug(
389
- "include_event_history_in_trajectories: trajectory[%d] already has "
390
- "%d events, skipping run_id=%s",
391
- idx,
392
- len(event_history),
393
- run_id,
394
- )
395
- continue
430
+ new_trace = build_trace_payload(
431
+ messages=messages or [],
432
+ response=response,
433
+ correlation_id=correlation_id,
434
+ metadata={"run_id": run_id},
435
+ )
396
436
 
397
- # Build event_history from provided messages/responses
398
- messages = (
399
- messages_by_trajectory[idx]
400
- if messages_by_trajectory and idx < len(messages_by_trajectory)
401
- else []
402
- )
403
- response = (
404
- responses_by_trajectory[idx]
405
- if responses_by_trajectory and idx < len(responses_by_trajectory)
406
- else None
407
- )
437
+ # Merge new trace payload into the existing trace block.
438
+ trace_meta = trace_block.get("metadata")
439
+ if isinstance(trace_meta, dict):
440
+ merged_meta = dict(new_trace.get("metadata", {}))
441
+ merged_meta.update(trace_meta)
442
+ trace_block["metadata"] = merged_meta
443
+ else:
444
+ trace_block["metadata"] = new_trace.get("metadata", {})
408
445
 
409
- # If no messages provided, try to extract from trajectory steps
410
- if not messages:
411
- steps = traj.get("steps", [])
412
- for step in steps:
413
- if isinstance(step, dict):
414
- obs = step.get("obs", {})
415
- if isinstance(obs, dict):
416
- step_messages = obs.get("messages")
417
- if isinstance(step_messages, list):
418
- messages = step_messages
419
- break
420
-
421
- # Build the trace with event_history
422
- new_trace = build_trajectory_trace(
423
- messages=messages,
424
- response=response,
425
- correlation_id=correlation_id or traj.get("trace_correlation_id"),
426
- metadata={"run_id": run_id, "trajectory_index": idx},
427
- )
446
+ trace_block.setdefault("schema_version", new_trace.get("schema_version"))
447
+ trace_block["event_history"] = new_trace.get("event_history", [])
448
+ trace_block.setdefault(
449
+ "markov_blanket_message_history",
450
+ new_trace.get("markov_blanket_message_history", []),
451
+ )
428
452
 
429
- # Merge with existing trace (preserve existing fields)
430
- trace.update(new_trace)
431
-
432
- logger.info(
433
- "include_event_history_in_trajectories: added event_history to "
434
- "trajectory[%d] run_id=%s events=%d",
435
- idx,
436
- run_id,
437
- len(trace.get("event_history", [])),
438
- )
453
+ if isinstance(session_trace, dict) and "event_history" not in session_trace:
454
+ session_trace["event_history"] = trace_block["event_history"]
439
455
 
456
+ logger.info(
457
+ "include_event_history_in_response: added event_history run_id=%s events=%d",
458
+ run_id,
459
+ len(trace_block.get("event_history", [])),
460
+ )
440
461
  return response_data
441
462
 
442
463
 
464
+ def include_event_history_in_trajectories(
465
+ response_data: dict[str, Any],
466
+ messages_by_trajectory: list[list[dict[str, Any]]] | None = None,
467
+ responses_by_trajectory: list[dict[str, Any]] | None = None,
468
+ *,
469
+ run_id: str,
470
+ correlation_id: str | None = None,
471
+ ) -> dict[str, Any]:
472
+ """Backward-compatible alias for include_event_history_in_response."""
473
+
474
+ messages = messages_by_trajectory[0] if messages_by_trajectory else None
475
+ response = responses_by_trajectory[0] if responses_by_trajectory else None
476
+ return include_event_history_in_response(
477
+ response_data,
478
+ messages=messages,
479
+ response=response,
480
+ run_id=run_id,
481
+ correlation_id=correlation_id,
482
+ )
483
+
484
+
443
485
  def verify_trace_correlation_id_in_response(
444
486
  response_data: dict[str, Any],
445
487
  expected_correlation_id: str | None,
@@ -480,15 +522,24 @@ def verify_trace_correlation_id_in_response(
480
522
  f"expected={expected_correlation_id} actual={pipeline_meta.get('trace_correlation_id') if isinstance(pipeline_meta, dict) else 'NOT_A_DICT'}"
481
523
  )
482
524
 
483
- # Check trajectories
484
- trajectories = response_data.get("trajectories", [])
485
- if isinstance(trajectories, list):
486
- for idx, traj in enumerate(trajectories):
487
- if isinstance(traj, dict) and traj.get("trace_correlation_id") != expected_correlation_id:
488
- errors.append(
489
- f"trajectory[{idx}] missing or mismatch: "
490
- f"expected={expected_correlation_id} actual={traj.get('trace_correlation_id')}"
491
- )
525
+ # Check trace metadata
526
+ trace_block = response_data.get("trace")
527
+ trace_meta_id = None
528
+ if isinstance(trace_block, dict):
529
+ trace_meta = trace_block.get("metadata")
530
+ if isinstance(trace_meta, dict):
531
+ trace_meta_id = trace_meta.get("trace_correlation_id")
532
+ if trace_meta_id != expected_correlation_id:
533
+ session_trace = trace_block.get("session_trace")
534
+ if isinstance(session_trace, dict):
535
+ session_meta = session_trace.get("metadata")
536
+ if isinstance(session_meta, dict):
537
+ trace_meta_id = session_meta.get("trace_correlation_id")
538
+ if trace_meta_id != expected_correlation_id:
539
+ errors.append(
540
+ "trace.metadata missing or mismatch: "
541
+ f"expected={expected_correlation_id} actual={trace_meta_id}"
542
+ )
492
543
 
493
544
  if errors:
494
545
  logger.error(
@@ -16,8 +16,8 @@ def validate_rollout_response_for_rl(response_data: dict[str, Any], *, warn_only
16
16
  """Validate that a task app rollout response has required fields for RL training.
17
17
 
18
18
  The backend RL trainer requires:
19
- 1. pipeline_metadata["inference_url"] at top level (with ?cid= for trace correlation)
20
- 2. Each step's info.meta["inference_url"] must be present (nested structure!)
19
+ 1. A v3 trace with event_history (preferred), OR
20
+ 2. pipeline_metadata["inference_url"] with ?cid= for trace hydration fallback
21
21
 
22
22
  Args:
23
23
  response_data: The rollout response dict from task app
@@ -31,16 +31,43 @@ def validate_rollout_response_for_rl(response_data: dict[str, Any], *, warn_only
31
31
  """
32
32
  issues = []
33
33
 
34
- # Check pipeline_metadata
34
+ trace_block = response_data.get("trace")
35
+ event_history = None
36
+ if isinstance(trace_block, dict):
37
+ event_history = trace_block.get("event_history")
38
+ if not event_history and isinstance(trace_block.get("session_trace"), dict):
39
+ event_history = trace_block["session_trace"].get("event_history")
40
+
41
+ has_event_history = isinstance(event_history, list) and len(event_history) > 0
42
+
43
+ trace_correlation_id = response_data.get("trace_correlation_id")
44
+ if not trace_correlation_id and isinstance(trace_block, dict):
45
+ trace_meta = trace_block.get("metadata")
46
+ if isinstance(trace_meta, dict):
47
+ trace_correlation_id = trace_meta.get("trace_correlation_id")
48
+
49
+ if not trace_correlation_id:
50
+ issues.append(
51
+ "Missing trace_correlation_id (top-level or trace.metadata). "
52
+ "RL trainer requires this to link traces."
53
+ )
54
+
55
+ if not has_event_history:
56
+ issues.append(
57
+ "trace.event_history is missing or empty. "
58
+ "Return a v3 trace or provide inference_url for hydration."
59
+ )
60
+
61
+ # Check pipeline_metadata inference_url only when trace is missing/empty
35
62
  pipeline_metadata = response_data.get("pipeline_metadata")
36
- if not isinstance(pipeline_metadata, dict):
37
- issues.append("Missing or invalid 'pipeline_metadata' (required for RL training)")
38
- else:
63
+ inference_url = None
64
+ if isinstance(pipeline_metadata, dict):
39
65
  inference_url = pipeline_metadata.get("inference_url")
66
+ if not has_event_history:
40
67
  if not inference_url:
41
68
  issues.append(
42
69
  "pipeline_metadata['inference_url'] is missing. "
43
- "RL trainer requires this field to extract traces."
70
+ "RL trainer needs this to hydrate traces when event_history is absent."
44
71
  )
45
72
  elif not isinstance(inference_url, str):
46
73
  issues.append(
@@ -52,48 +79,6 @@ def validate_rollout_response_for_rl(response_data: dict[str, Any], *, warn_only
52
79
  f"Got: {inference_url[:80]}..."
53
80
  )
54
81
 
55
- # Check trajectories and steps
56
- trajectories = response_data.get("trajectories", [])
57
- if not trajectories:
58
- issues.append("No trajectories found in response")
59
-
60
- for traj_idx, trajectory in enumerate(trajectories):
61
- if not isinstance(trajectory, dict):
62
- continue
63
-
64
- steps = trajectory.get("steps", [])
65
- for step_idx, step in enumerate(steps):
66
- if not isinstance(step, dict):
67
- continue
68
-
69
- step_info = step.get("info", {})
70
- if not isinstance(step_info, dict):
71
- issues.append(
72
- f"trajectory[{traj_idx}].steps[{step_idx}].info is not a dict"
73
- )
74
- continue
75
-
76
- # Check for nested meta.inference_url (backend expects this structure!)
77
- step_meta = step_info.get("meta", {})
78
- if not isinstance(step_meta, dict):
79
- issues.append(
80
- f"trajectory[{traj_idx}].steps[{step_idx}].info.meta is missing or not a dict. "
81
- f"RL trainer expects nested structure: info.meta.inference_url"
82
- )
83
- continue
84
-
85
- step_inference_url = step_meta.get("inference_url")
86
- if not step_inference_url:
87
- issues.append(
88
- f"trajectory[{traj_idx}].steps[{step_idx}].info.meta['inference_url'] is missing. "
89
- f"RL trainer needs this for trace extraction (nested structure required!)"
90
- )
91
- elif not isinstance(step_inference_url, str):
92
- issues.append(
93
- f"trajectory[{traj_idx}].steps[{step_idx}].info.meta['inference_url'] must be a string, "
94
- f"got: {type(step_inference_url).__name__}"
95
- )
96
-
97
82
  if issues and not warn_only:
98
83
  error_msg = "Task app response validation failed for RL training:\n" + "\n".join(
99
84
  f" - {issue}" for issue in issues
@@ -4,7 +4,7 @@ This module provides high-level APIs for running training jobs:
4
4
  - PromptLearningJob: GEPA and MIPRO prompt optimization
5
5
  - SFTJob: Supervised fine-tuning
6
6
  - RLJob: Reinforcement learning (GSPO, GRPO, PPO, etc.)
7
- - GraphGenJob: Automated Design of Agentic Systems (simplified workflows API)
7
+ - GraphGenJob: Graph Opt (simplified workflows API)
8
8
 
9
9
  Example:
10
10
  from synth_ai.sdk.training import PromptLearningJob, RLJob, GraphGenJob
@@ -30,7 +30,7 @@ Example:
30
30
  from __future__ import annotations
31
31
 
32
32
  # Pollers and utilities
33
- from synth_ai.sdk.api.train.pollers import JobPoller, PollOutcome, RLJobPoller
33
+ from synth_ai.sdk.api.train.pollers import JobPoller, PollOutcome, RLJobPoller, EvalJobPoller
34
34
 
35
35
  # Re-export from existing locations
36
36
  from synth_ai.sdk.api.train.prompt_learning import (
@@ -41,7 +41,7 @@ from synth_ai.sdk.api.train.prompt_learning import (
41
41
  from synth_ai.sdk.api.train.rl import RLJob, RLJobConfig
42
42
  from synth_ai.sdk.api.train.sft import SFTJob
43
43
 
44
- # GraphGen (formerly GraphGen)
44
+ # GraphGen (Graph Opt)
45
45
  from synth_ai.sdk.api.train.graphgen import GraphGenJob, GraphGenJobResult, GraphGenSubmitResult
46
46
  from synth_ai.sdk.api.train.graphgen_models import (
47
47
  GraphGenJobConfig,
@@ -49,16 +49,7 @@ from synth_ai.sdk.api.train.graphgen_models import (
49
49
  GraphGenTask,
50
50
  GraphGenGoldOutput,
51
51
  GraphGenRubric,
52
- GraphGenJudgeConfig,
53
- load_graphgen_taskset,
54
- parse_graphgen_taskset,
55
- # GraphGen aliases
56
- GraphGenJobConfig,
57
- GraphGenTaskSet,
58
- GraphGenTask,
59
- GraphGenGoldOutput,
60
- GraphGenRubric,
61
- GraphGenJudgeConfig,
52
+ GraphGenVerifierConfig,
62
53
  load_graphgen_taskset,
63
54
  parse_graphgen_taskset,
64
55
  )
@@ -80,7 +71,7 @@ __all__ = [
80
71
  "GraphGenTask",
81
72
  "GraphGenGoldOutput",
82
73
  "GraphGenRubric",
83
- "GraphGenJudgeConfig",
74
+ "GraphGenVerifierConfig",
84
75
  "load_graphgen_taskset",
85
76
  "parse_graphgen_taskset",
86
77
  # GraphGen (legacy aliases)
@@ -92,11 +83,11 @@ __all__ = [
92
83
  "GraphGenTask",
93
84
  "GraphGenGoldOutput",
94
85
  "GraphGenRubric",
95
- "GraphGenJudgeConfig",
86
+ "GraphGenVerifierConfig",
96
87
  "load_graphgen_taskset",
97
88
  "parse_graphgen_taskset",
98
89
  # Utils
99
90
  "JobPoller",
100
91
  "PollOutcome",
92
+ "EvalJobPoller",
101
93
  ]
102
-