synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic; see the package's registry page for more details.

Files changed (169)
  1. examples/baseline/banking77_baseline.py +204 -0
  2. examples/baseline/crafter_baseline.py +407 -0
  3. examples/baseline/pokemon_red_baseline.py +326 -0
  4. examples/baseline/simple_baseline.py +56 -0
  5. examples/baseline/warming_up_to_rl_baseline.py +239 -0
  6. examples/blog_posts/gepa/README.md +355 -0
  7. examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
  8. examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
  9. examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
  10. examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
  11. examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
  12. examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
  13. examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
  14. examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
  15. examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
  16. examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
  17. examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
  18. examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
  19. examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
  20. examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
  21. examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
  22. examples/blog_posts/gepa/gepa_baseline.py +204 -0
  23. examples/blog_posts/gepa/query_prompts_example.py +97 -0
  24. examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
  25. examples/blog_posts/gepa/task_apps.py +105 -0
  26. examples/blog_posts/gepa/test_gepa_local.sh +67 -0
  27. examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
  28. examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
  29. examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
  30. examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
  31. examples/blog_posts/pokemon_vl/extract_images.py +239 -0
  32. examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
  33. examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
  34. examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
  35. examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
  36. examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
  37. examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
  38. examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
  39. examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
  40. examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
  41. examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
  42. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
  43. examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
  44. examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
  45. examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
  46. examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
  47. examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
  48. examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
  49. examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
  50. examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
  51. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
  52. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
  53. examples/rl/configs/rl_from_base_qwen17.toml +1 -0
  54. examples/swe/task_app/hosted/inference/openai_client.py +0 -34
  55. examples/swe/task_app/hosted/policy_routes.py +17 -0
  56. examples/swe/task_app/hosted/rollout.py +4 -2
  57. examples/task_apps/banking77/__init__.py +6 -0
  58. examples/task_apps/banking77/banking77_task_app.py +841 -0
  59. examples/task_apps/banking77/deploy_wrapper.py +46 -0
  60. examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
  61. examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
  62. examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
  63. examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
  64. examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
  65. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
  66. examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
  67. examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
  68. examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
  69. examples/task_apps/gepa_benchmarks/__init__.py +7 -0
  70. examples/task_apps/gepa_benchmarks/common.py +260 -0
  71. examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
  72. examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
  73. examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
  74. examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
  75. examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
  76. examples/task_apps/pokemon_red/task_app.py +254 -36
  77. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
  78. examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
  84. examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
  85. synth_ai/api/train/builders.py +90 -1
  86. synth_ai/api/train/cli.py +396 -21
  87. synth_ai/api/train/config_finder.py +13 -2
  88. synth_ai/api/train/configs/__init__.py +15 -1
  89. synth_ai/api/train/configs/prompt_learning.py +442 -0
  90. synth_ai/api/train/configs/rl.py +29 -0
  91. synth_ai/api/train/task_app.py +1 -1
  92. synth_ai/api/train/validators.py +277 -0
  93. synth_ai/baseline/__init__.py +25 -0
  94. synth_ai/baseline/config.py +209 -0
  95. synth_ai/baseline/discovery.py +214 -0
  96. synth_ai/baseline/execution.py +146 -0
  97. synth_ai/cli/__init__.py +85 -17
  98. synth_ai/cli/__main__.py +0 -0
  99. synth_ai/cli/claude.py +70 -0
  100. synth_ai/cli/codex.py +84 -0
  101. synth_ai/cli/commands/__init__.py +1 -0
  102. synth_ai/cli/commands/baseline/__init__.py +12 -0
  103. synth_ai/cli/commands/baseline/core.py +637 -0
  104. synth_ai/cli/commands/baseline/list.py +93 -0
  105. synth_ai/cli/commands/eval/core.py +13 -10
  106. synth_ai/cli/commands/filter/core.py +53 -17
  107. synth_ai/cli/commands/help/core.py +0 -1
  108. synth_ai/cli/commands/smoke/__init__.py +7 -0
  109. synth_ai/cli/commands/smoke/core.py +1436 -0
  110. synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
  111. synth_ai/cli/commands/status/subcommands/usage.py +203 -0
  112. synth_ai/cli/commands/train/judge_schemas.py +1 -0
  113. synth_ai/cli/commands/train/judge_validation.py +1 -0
  114. synth_ai/cli/commands/train/validation.py +0 -57
  115. synth_ai/cli/demo.py +35 -3
  116. synth_ai/cli/deploy/__init__.py +40 -25
  117. synth_ai/cli/deploy.py +162 -0
  118. synth_ai/cli/legacy_root_backup.py +14 -8
  119. synth_ai/cli/opencode.py +107 -0
  120. synth_ai/cli/root.py +9 -5
  121. synth_ai/cli/task_app_deploy.py +1 -1
  122. synth_ai/cli/task_apps.py +53 -53
  123. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
  124. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
  125. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
  126. synth_ai/judge_schemas.py +1 -0
  127. synth_ai/learning/__init__.py +10 -0
  128. synth_ai/learning/prompt_learning_client.py +276 -0
  129. synth_ai/learning/prompt_learning_types.py +184 -0
  130. synth_ai/pricing/__init__.py +2 -0
  131. synth_ai/pricing/model_pricing.py +57 -0
  132. synth_ai/streaming/handlers.py +53 -4
  133. synth_ai/streaming/streamer.py +19 -0
  134. synth_ai/task/apps/__init__.py +1 -0
  135. synth_ai/task/config.py +2 -0
  136. synth_ai/task/tracing_utils.py +25 -25
  137. synth_ai/task/validators.py +44 -8
  138. synth_ai/task_app_cfgs.py +21 -0
  139. synth_ai/tracing_v3/config.py +162 -19
  140. synth_ai/tracing_v3/constants.py +1 -1
  141. synth_ai/tracing_v3/db_config.py +24 -38
  142. synth_ai/tracing_v3/storage/config.py +47 -13
  143. synth_ai/tracing_v3/storage/factory.py +3 -3
  144. synth_ai/tracing_v3/turso/daemon.py +113 -11
  145. synth_ai/tracing_v3/turso/native_manager.py +92 -16
  146. synth_ai/types.py +8 -0
  147. synth_ai/urls.py +11 -0
  148. synth_ai/utils/__init__.py +30 -1
  149. synth_ai/utils/agents.py +74 -0
  150. synth_ai/utils/bin.py +39 -0
  151. synth_ai/utils/cli.py +149 -5
  152. synth_ai/utils/env.py +17 -17
  153. synth_ai/utils/json.py +72 -0
  154. synth_ai/utils/modal.py +283 -1
  155. synth_ai/utils/paths.py +48 -0
  156. synth_ai/utils/uvicorn.py +113 -0
  157. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
  158. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
  159. synth_ai/cli/commands/deploy/__init__.py +0 -23
  160. synth_ai/cli/commands/deploy/core.py +0 -614
  161. synth_ai/cli/commands/deploy/errors.py +0 -72
  162. synth_ai/cli/commands/deploy/validation.py +0 -11
  163. synth_ai/cli/deploy/core.py +0 -5
  164. synth_ai/cli/deploy/errors.py +0 -23
  165. synth_ai/cli/deploy/validation.py +0 -5
  166. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
  167. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
  168. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
  169. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
synth_ai/api/train/cli.py CHANGED
@@ -1,7 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
+ import contextlib
4
5
  import importlib
6
+ import json
5
7
  import os
6
8
  import time
7
9
  from collections.abc import Callable, Mapping
@@ -27,7 +29,7 @@ from synth_ai.streaming import (
27
29
  StreamType,
28
30
  )
29
31
 
30
- from .builders import build_rl_payload, build_sft_payload
32
+ from .builders import build_prompt_learning_payload, build_rl_payload, build_sft_payload
31
33
  from .config_finder import discover_configs, prompt_for_config
32
34
  from .env_resolver import KeySpec, resolve_env
33
35
  from .task_app import check_task_app_health
@@ -45,6 +47,45 @@ from .utils import (
45
47
  validate_sft_jsonl,
46
48
  )
47
49
 
50
+ # Constants for prompt learning event types
51
+ _PROMPT_LEARNING_EVENT_BEST_PROMPT = "prompt.learning.best.prompt"
52
+ _PROMPT_LEARNING_EVENT_FINAL_RESULTS = "prompt.learning.final.results"
53
+ _PROMPT_LEARNING_EVENT_VALIDATION_SCORED = "prompt.learning.validation.scored"
54
+ _PROMPT_LEARNING_EVENT_GEPA_COMPLETE = "prompt.learning.gepa.complete"
55
+
56
+ # Constants for formatting
57
+ _MAX_TEXT_REPLACEMENTS_DISPLAY = 3 # Max number of text replacements to show in output
58
+ _RESULTS_FILE_MAX_EVENTS = 10000 # Max events to fetch for results file generation
59
+
60
+
61
+ def _format_text_replacements(obj: dict[str, Any] | None, max_display: int = _MAX_TEXT_REPLACEMENTS_DISPLAY) -> list[str]:
62
+ """Extract and format text replacements from a candidate object.
63
+
64
+ Args:
65
+ obj: Candidate object dictionary containing text_replacements
66
+ max_display: Maximum number of replacements to display
67
+
68
+ Returns:
69
+ List of formatted lines showing role and replacement text
70
+ """
71
+ lines = []
72
+ if not obj or not isinstance(obj, dict):
73
+ return lines
74
+
75
+ text_replacements = obj.get("text_replacements", [])
76
+ if not text_replacements or not isinstance(text_replacements, list):
77
+ return lines
78
+
79
+ for replacement in text_replacements[:max_display]:
80
+ if isinstance(replacement, dict):
81
+ new_text = replacement.get("new_text", "")
82
+ role = replacement.get("apply_to_role", "system")
83
+ if new_text:
84
+ lines.append(f" [{role.upper()}]: {new_text}")
85
+ lines.append("")
86
+
87
+ return lines
88
+
48
89
 
49
90
  def _discover_dataset_candidates(
50
91
  config_path: Path, limit: int = 50, timeout: float = 10.0
@@ -164,6 +205,10 @@ _DEFAULT_SFT_HIDDEN_EVENTS = {
164
205
 
165
206
  _DEFAULT_RL_HIDDEN_SUBSTRINGS = {"modal", "hatchet"}
166
207
 
208
+ _DEFAULT_PROMPT_LEARNING_HIDDEN_EVENTS = {
209
+ "prompt.learning.policy.tokens",
210
+ }
211
+
167
212
 
168
213
  def _build_stream_components(
169
214
  stream_format: str,
@@ -208,7 +253,7 @@ def _build_stream_components(
208
253
  type=click.Path(),
209
254
  help="Path to training TOML (repeatable)",
210
255
  )
211
- @click.option("--type", "train_type", type=click.Choice(["auto", "rl", "sft"]), default="auto")
256
+ @click.option("--type", "train_type", type=click.Choice(["auto", "rl", "sft", "prompt_learning"]), default="auto")
212
257
  @click.option(
213
258
  "--env-file",
214
259
  "env_files",
@@ -279,7 +324,7 @@ def train_command(
279
324
  stream_format: str,
280
325
  examples_limit: int | None,
281
326
  ) -> None:
282
- """Interactive launcher for RL / SFT jobs."""
327
+ """Interactive launcher for RL / SFT / Prompt Learning jobs."""
283
328
 
284
329
  candidates = discover_configs(
285
330
  list(config_paths), requested_type=train_type if train_type != "auto" else None
@@ -291,16 +336,16 @@ def train_command(
291
336
  )
292
337
 
293
338
  effective_type = train_type if train_type != "auto" else selection.train_type
294
- if effective_type not in {"rl", "sft"}:
339
+ if effective_type not in {"rl", "sft", "prompt_learning"}:
295
340
  effective_type = click.prompt(
296
- "Detected config type is ambiguous. Enter type", type=click.Choice(["rl", "sft"])
341
+ "Detected config type is ambiguous. Enter type", type=click.Choice(["rl", "sft", "prompt_learning"])
297
342
  )
298
343
 
299
344
  cfg_path = selection.path
300
345
  click.echo(f"Using config: {cfg_path} ({effective_type})")
301
346
 
302
347
  required_keys: list[KeySpec] = []
303
- if effective_type == "rl":
348
+ if effective_type == "rl" or effective_type == "prompt_learning":
304
349
  required_keys.append(KeySpec("SYNTH_API_KEY", "Synth API key for backend"))
305
350
  required_keys.append(
306
351
  KeySpec(
@@ -377,6 +422,19 @@ def train_command(
377
422
  poll_interval=poll_interval,
378
423
  stream_format=stream_format,
379
424
  )
425
+ elif effective_type == "prompt_learning":
426
+ handle_prompt_learning(
427
+ cfg_path=cfg_path,
428
+ backend_base=backend_base,
429
+ synth_key=synth_key,
430
+ task_url_override=task_url,
431
+ allow_experimental=allow_experimental,
432
+ dry_run=dry_run,
433
+ poll=poll,
434
+ poll_timeout=poll_timeout,
435
+ poll_interval=poll_interval,
436
+ stream_format=stream_format,
437
+ )
380
438
  else:
381
439
  dataset_override_path = Path(dataset_path).expanduser().resolve() if dataset_path else None
382
440
  handle_sft(
@@ -415,7 +473,7 @@ def _wait_for_training_file(
415
473
  if resp.status_code == 200:
416
474
  try:
417
475
  data = resp.json()
418
- except Exception:
476
+ except json.JSONDecodeError:
419
477
  data = {}
420
478
  status = str(
421
479
  data.get("status") or data.get("state") or data.get("storage_state") or "ready"
@@ -440,7 +498,7 @@ def _wait_for_training_file(
440
498
  # Auth errors won't resolve by polling - fail immediately
441
499
  try:
442
500
  error_body = resp.json()
443
- except Exception:
501
+ except json.JSONDecodeError:
444
502
  error_body = resp.text[:400]
445
503
  click.echo("\n[ERROR] Authentication failed when checking training file:")
446
504
  click.echo(f" URL: {url}")
@@ -455,7 +513,7 @@ def _wait_for_training_file(
455
513
  # Other errors - show details but keep polling
456
514
  try:
457
515
  error_body = resp.json()
458
- except Exception:
516
+ except json.JSONDecodeError:
459
517
  error_body = resp.text[:400]
460
518
  click.echo(f"[WARN] Unexpected response checking file {file_id}:")
461
519
  click.echo(f" URL: {url}")
@@ -507,7 +565,7 @@ def handle_rl(
507
565
  )
508
566
  try:
509
567
  parsed_json = vresp.json()
510
- except Exception:
568
+ except json.JSONDecodeError:
511
569
  parsed_json = None
512
570
 
513
571
  if isinstance(parsed_json, Mapping):
@@ -542,8 +600,9 @@ def handle_rl(
542
600
  )
543
601
  statuses = [attempt.get("status") for attempt in attempts]
544
602
  click.echo(f"Verification OK (candidates={cands}, statuses={statuses})")
545
- except Exception:
546
- pass
603
+ except (KeyError, ValueError, AttributeError):
604
+ # Parsing verification summary failed, but verification itself succeeded
605
+ click.echo("Verification OK")
547
606
 
548
607
  env_key = os.environ.get("ENVIRONMENT_API_KEY")
549
608
  if not env_key:
@@ -568,7 +627,8 @@ def handle_rl(
568
627
  resp = http_post(create_url, headers=headers, json_body=build.payload)
569
628
  try:
570
629
  js = resp.json()
571
- except Exception:
630
+ except json.JSONDecodeError as e:
631
+ click.echo(f"⚠️ Failed to parse JSON response: {e}")
572
632
  js = {"status": resp.status_code, "text": resp.text[:400]}
573
633
  click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
574
634
  if resp.status_code not in (200, 201):
@@ -582,11 +642,27 @@ def handle_rl(
582
642
  return
583
643
 
584
644
  click.echo("\n=== Streaming Job Progress ===")
585
- config, handlers = _build_stream_components(
586
- stream_format, hidden_event_substrings=_DEFAULT_RL_HIDDEN_SUBSTRINGS
587
- )
645
+
646
+ # Enable metrics for prompt learning
588
647
  if stream_format == "chart":
589
- click.echo("Using live loss chart (metric=train.loss)")
648
+ config = StreamConfig(
649
+ enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
650
+ event_types={
651
+ "prompt.learning.progress",
652
+ "prompt.learning.gepa.start",
653
+ "prompt.learning.gepa.complete",
654
+ },
655
+ metric_names={"gepa.transformation.mean_score"},
656
+ )
657
+ handlers = [LossCurveHandler()]
658
+ click.echo("Using live chart (metric=gepa.transformation.mean_score)")
659
+ else:
660
+ config = StreamConfig(
661
+ enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
662
+ metric_names={"gepa.transformation.mean_score"},
663
+ )
664
+ handlers = [CLIHandler(hidden_event_substrings=_DEFAULT_RL_HIDDEN_SUBSTRINGS)]
665
+
590
666
  streamer = JobStreamer(
591
667
  base_url=backend_base,
592
668
  api_key=synth_key,
@@ -758,15 +834,314 @@ def handle_sft(
758
834
  timeout_seconds=poll_timeout,
759
835
  )
760
836
  final_status = asyncio.run(streamer.stream_until_terminal())
761
- click.echo(f"Final status: {final_status.get('status', 'unknown')}")
837
+ status = final_status.get('status') if isinstance(final_status, dict) else 'unknown'
838
+ click.echo(f"Final status: {status}")
762
839
  click.echo(preview_json(final_status, limit=600))
763
840
  finally:
764
841
  if limited_path is not None:
765
- try:
842
+ with contextlib.suppress(OSError):
766
843
  limited_path.unlink(missing_ok=True)
844
+ # Clean up empty parent directory if possible
845
+ with contextlib.suppress(OSError):
767
846
  limited_path.parent.rmdir()
768
- except Exception:
769
- pass
847
+
848
+
849
def _pl_coerce_score(value: Any, default: float | None = None) -> float | None:
    """Return ``value`` as a float when it is numeric, otherwise ``default``."""
    if isinstance(value, (int, float)):
        return float(value)
    return default


def _pl_collect_results(events: list[Any]) -> dict[str, Any]:
    """Reduce a prompt-learning event list to the fields used by the report.

    Non-dict events are skipped. Later events of the same type overwrite earlier
    ones. A ``gepa.complete`` best score is only used while no best score has
    been recorded yet.

    Returns:
        Dict with keys ``best_score``, ``best_prompt``, ``baseline_score``,
        ``attempted_candidates`` and ``optimized_candidates``. Values are taken
        raw from the events; the candidate entries default to ``[]``.
    """
    results: dict[str, Any] = {
        "best_score": None,
        "best_prompt": None,
        "baseline_score": None,
        "attempted_candidates": [],
        "optimized_candidates": [],
    }
    for event in events:
        if not isinstance(event, dict):
            continue  # Skip malformed events
        event_type = event.get("type", "")
        event_data = event.get("data")
        if not isinstance(event_data, dict):
            event_data = {}

        if event_type == _PROMPT_LEARNING_EVENT_BEST_PROMPT:
            results["best_score"] = event_data.get("best_score")
            results["best_prompt"] = event_data.get("best_prompt")
        elif event_type == _PROMPT_LEARNING_EVENT_FINAL_RESULTS:
            results["attempted_candidates"] = event_data.get("attempted_candidates") or []
            results["optimized_candidates"] = event_data.get("optimized_candidates") or []
        elif event_type == _PROMPT_LEARNING_EVENT_VALIDATION_SCORED:
            # The baseline is identified by an ``is_baseline`` flag or, failing
            # that, by "baseline" appearing in the event message.
            is_baseline = bool(event_data.get("is_baseline", False))
            if not is_baseline:
                is_baseline = "baseline" in str(event.get("message") or "").lower()
            if is_baseline:
                results["baseline_score"] = event_data.get("accuracy")
        elif event_type == _PROMPT_LEARNING_EVENT_GEPA_COMPLETE and results["best_score"] is None:
            results["best_score"] = event_data.get("best_score")
    return results


def _pl_render_report(*, job_id: str, generated_at: str, results: dict[str, Any]) -> str:
    """Render the plain-text GEPA results report from collected event data.

    Missing or non-numeric scores are skipped (summary) or shown as 0
    (per-candidate) instead of aborting the whole report.
    """
    rule = "=" * 80
    divider = "-" * 80
    baseline_score = _pl_coerce_score(results["baseline_score"])
    best_score = _pl_coerce_score(results["best_score"])
    best_prompt = results["best_prompt"]
    optimized_candidates = results["optimized_candidates"]
    attempted_candidates = results["attempted_candidates"]

    lines: list[str] = [
        rule,
        "GEPA PROMPT LEARNING RESULTS",
        rule,
        f"Job ID: {job_id}",
        f"Timestamp: {generated_at}",
        "",
    ]
    if baseline_score is not None:
        lines.append(f"📊 Baseline Score: {baseline_score:.4f} ({baseline_score*100:.1f}%)")
    if best_score is not None:
        lines.append(f"🏆 Best Score: {best_score:.4f} ({best_score*100:.1f}%)")
    if baseline_score is not None and best_score is not None:
        absolute_pp = (best_score - baseline_score) * 100
        if baseline_score > 0:
            relative_pct = (best_score - baseline_score) / baseline_score * 100
            lines.append(
                f"📈 Improvement: {relative_pct:+.1f}% relative ({absolute_pp:+.1f} pp absolute)"
            )
        else:
            # Relative change is undefined against a non-positive baseline;
            # report only the absolute delta instead of a bogus "+0.0%".
            lines.append(f"📈 Improvement: {absolute_pp:+.1f} pp absolute")
    lines.append(rule)
    lines.append("")

    # Best prompt, section by section
    if isinstance(best_prompt, dict) and best_prompt:
        lines.append("🏆 BEST PROMPT")
        lines.append(divider)
        sections = best_prompt.get("sections")
        for section in sections if isinstance(sections, list) else []:
            if not isinstance(section, dict):
                continue
            role = section.get("role") or "unknown"
            content = section.get("content")
            lines.append(f"\n[{str(role).upper()}]:")
            # A null content would make the final "\n".join raise TypeError.
            lines.append("" if content is None else str(content))
            lines.append("")

    # Optimized (Pareto/top) candidates
    if isinstance(optimized_candidates, list) and optimized_candidates:
        lines.extend([rule, f"✨ TOP OPTIMIZED CANDIDATES ({len(optimized_candidates)})", rule, ""])
        for idx, cand in enumerate(optimized_candidates, start=1):
            if not isinstance(cand, dict):
                continue
            candidate_score = cand.get("score")
            if not isinstance(candidate_score, dict):
                candidate_score = {}
            accuracy = _pl_coerce_score(candidate_score.get("accuracy"), 0.0)
            prompt_length = candidate_score.get("prompt_length", 0)
            payload_kind = cand.get("payload_kind", "unknown")

            # Prefer score.instance_scores; fall back to the candidate's own field.
            instance_scores = (
                candidate_score["instance_scores"]
                if "instance_scores" in candidate_score
                else cand.get("instance_scores")
            )
            n_eval = len(instance_scores) if isinstance(instance_scores, list) else 0

            lines.append(
                f"[{idx}] Accuracy: {accuracy:.4f} | Length: {prompt_length} "
                f"| Type: {payload_kind} | N: {n_eval}"
            )
            lines.append(divider)

            obj = cand.get("object")
            if isinstance(obj, dict) and obj and payload_kind == "transformation":
                # For transformations, text_replacements are nested in data
                lines.extend(_format_text_replacements(obj.get("data", {})))
            lines.append("")

    # Every proposal that was attempted
    if isinstance(attempted_candidates, list) and attempted_candidates:
        lines.extend([rule, f"💡 ALL PROPOSAL CANDIDATES ({len(attempted_candidates)})", rule, ""])
        for idx, cand in enumerate(attempted_candidates, start=1):
            if not isinstance(cand, dict):
                continue
            accuracy = _pl_coerce_score(cand.get("accuracy"), 0.0)
            prompt_length = cand.get("prompt_length", 0)
            tool_rate = _pl_coerce_score(cand.get("tool_call_rate"), 0.0)
            instance_scores = cand.get("instance_scores")
            n_eval = len(instance_scores) if isinstance(instance_scores, list) else 0

            lines.append(
                f"[{idx}] Accuracy: {accuracy:.4f} | Length: {prompt_length} "
                f"| Tool Rate: {tool_rate:.2f} | N: {n_eval}"
            )
            lines.append(divider)

            obj = cand.get("object")
            if isinstance(obj, dict) and obj:
                # For proposals, text_replacements are at top level of object
                lines.extend(_format_text_replacements(obj))
            lines.append("")

    lines.extend([rule, "END OF REPORT", rule])
    return "\n".join(lines)


def _save_prompt_learning_results_locally(
    *,
    backend_base: str,
    api_key: str,
    job_id: str,
    config_path: Path,
) -> None:
    """Fetch job events and write a GEPA results report next to the config.

    The report goes to ``<config dir>/results/gepa_results_<job_id>_<timestamp>.txt``.
    Nothing is written when the events cannot be fetched, are malformed, or
    contain no candidate results. This is best-effort post-processing: failures
    are reported with ``click.echo`` and never raised.

    Args:
        backend_base: Backend API base URL.
        api_key: Synth API key sent as the bearer token.
        job_id: Prompt-learning job whose events are fetched.
        config_path: TOML config path; its directory hosts the ``results`` folder.
    """
    from datetime import datetime

    try:
        url = (
            f"{backend_base}/prompt-learning/online/jobs/{job_id}/events"
            f"?limit={_RESULTS_FILE_MAX_EVENTS}"
        )
        headers = {"Authorization": f"Bearer {api_key}"}
        resp = http_get(url, headers=headers, timeout=30.0)
        if resp.status_code != 200:
            click.echo(f"⚠️ Could not fetch events to generate results file (status={resp.status_code})")
            return

        data = resp.json()
        if not isinstance(data, dict):
            click.echo(f"⚠️ Unexpected response type: {type(data).__name__}")
            return
        events = data.get("events", [])
        if not isinstance(events, list):
            click.echo(f"⚠️ Events field is not a list: {type(events).__name__}")
            return
        if not events:
            return

        results = _pl_collect_results(events)
        if not (results["attempted_candidates"] or results["optimized_candidates"]):
            return

        # One timestamp for both the report header and the filename so they agree.
        now = datetime.now()
        report = _pl_render_report(
            job_id=job_id,
            generated_at=now.strftime("%Y-%m-%d %H:%M:%S"),
            results=results,
        )

        output_dir = config_path.parent / "results"
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"gepa_results_{job_id}_{now.strftime('%Y%m%d_%H%M%S')}.txt"
        output_file.write_text(report, encoding="utf-8")

        click.echo(f"\n📄 Results saved locally to: {output_file}")

    except OSError as e:  # includes PermissionError
        click.echo(f"⚠️ Could not save results file locally: {e}")
    except Exception as e:
        # Best-effort boundary: reporting problems must never abort the CLI run.
        click.echo(f"⚠️ Unexpected error saving results file: {e}")
1034
+
1035
+
1036
def handle_prompt_learning(
    *,
    cfg_path: Path,
    backend_base: str,
    synth_key: str,
    task_url_override: str | None,
    allow_experimental: bool | None,
    dry_run: bool,
    poll: bool,
    poll_timeout: float,
    poll_interval: float,
    stream_format: str,
) -> None:
    """Handle prompt learning job creation (MIPRO or GEPA).

    Builds the job payload from the TOML at ``cfg_path`` and health-checks the
    task app it points at. It then creates the job on the backend and, when
    ``poll`` is set, streams progress until a terminal state. After streaming,
    a results report is saved next to the config.

    Args:
        cfg_path: Prompt-learning TOML config.
        backend_base: Backend API base URL.
        synth_key: Synth API key sent as the bearer token.
        task_url_override: Accepted for parity with the other handlers but not
            applied; the task app URL always comes from the TOML (a warning is
            printed when one is supplied).
        allow_experimental: Forwarded to the payload builder.
        dry_run: When True, validate the config and task app and print the
            payload, but do not create a job.
        poll: When False, return right after the job is created.
        poll_timeout: Streaming timeout in seconds.
        poll_interval: Streaming poll interval in seconds.
        stream_format: ``"chart"`` for the live metric chart, otherwise CLI output.

    Raises:
        click.ClickException: ``ENVIRONMENT_API_KEY`` is unset, the task app
            health check fails, job creation fails, or the response has no job id.
    """
    if task_url_override:
        click.echo(
            f"⚠️ Ignoring task URL override ({task_url_override}); "
            "prompt learning uses the task app URL from the TOML config"
        )

    overrides: dict[str, Any] = {
        "backend": backend_base,
    }

    build = build_prompt_learning_payload(
        config_path=cfg_path,
        task_url=None,  # Force using TOML only
        overrides=overrides,
        allow_experimental=allow_experimental,
    )

    env_key = os.environ.get("ENVIRONMENT_API_KEY")
    if not env_key:
        raise click.ClickException("ENVIRONMENT_API_KEY required for prompt learning flow")

    click.echo("Performing task app health check…")
    health = check_task_app_health(build.task_url, env_key)
    if not health.ok:
        click.echo(f"Task app health check failed: {health.detail}")
        raise click.ClickException("Aborting due to failing health check")
    click.echo("Task app healthy")

    create_url = f"{backend_base}/prompt-learning/online/jobs"
    headers = {"Authorization": f"Bearer {synth_key}", "Content-Type": "application/json"}

    click.echo(f"POST {create_url}")
    click.echo("Payload preview:\n" + preview_json(build.payload, limit=800))

    # Honor --dry-run: everything above is validation only; do not create a job.
    if dry_run:
        click.echo("Dry run: job not submitted")
        return

    resp = http_post(create_url, headers=headers, json_body=build.payload)
    try:
        js = resp.json()
    except json.JSONDecodeError as e:
        click.echo(f"⚠️ Failed to parse JSON response: {e}")
        js = {"status": resp.status_code, "text": resp.text[:400]}
    click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
    if resp.status_code not in (200, 201):
        raise click.ClickException("Job creation failed")
    # A non-object JSON body (e.g. a list) has no .get(); treat it as missing an id.
    job_id = (js.get("job_id") or js.get("id")) if isinstance(js, Mapping) else None
    if not job_id:
        raise click.ClickException("Response missing job id")

    if not poll:
        click.echo(f"Created job {job_id} (polling disabled)")
        return

    click.echo("\n=== Streaming Job Progress ===")

    # Custom config for prompt learning to enable metrics
    if stream_format == "chart":
        config = StreamConfig(
            enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
            event_types={
                "prompt.learning.progress",
                "prompt.learning.gepa.start",
                "prompt.learning.gepa.complete",
            },
            metric_names={"gepa.transformation.mean_score"},
        )
        handlers = [LossCurveHandler()]
        click.echo("Using live loss chart (metric=gepa.transformation.mean_score)")
    else:
        # Enable metrics for CLI mode too
        config = StreamConfig(
            enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
            metric_names={"gepa.transformation.mean_score"},
        )
        handlers = [
            CLIHandler(
                hidden_event_types=_DEFAULT_PROMPT_LEARNING_HIDDEN_EVENTS,
                hidden_event_substrings=_DEFAULT_RL_HIDDEN_SUBSTRINGS,
            )
        ]

    streamer = JobStreamer(
        base_url=backend_base,
        api_key=synth_key,
        job_id=job_id,
        endpoints=StreamEndpoints.prompt_learning(job_id),
        config=config,
        handlers=handlers,
        interval_seconds=poll_interval,
        timeout_seconds=poll_timeout,
    )
    final_status = asyncio.run(streamer.stream_until_terminal())
    # Same defensive check as handle_sft: the streamer result may not be a mapping.
    status = final_status.get("status", "unknown") if isinstance(final_status, Mapping) else "unknown"
    click.echo(f"Final status: {status}")
    click.echo(preview_json(final_status, limit=600))

    # Save results file locally
    _save_prompt_learning_results_locally(
        backend_base=backend_base,
        api_key=synth_key,
        job_id=job_id,
        config_path=cfg_path,
    )
770
1145
 
771
1146
 
772
1147
  def register(cli: click.Group) -> None:
@@ -18,7 +18,7 @@ _STATE_FILE = _STATE_DIR / "train_cli.json"
18
18
  @dataclass(slots=True)
19
19
  class ConfigCandidate:
20
20
  path: Path
21
- train_type: str # "rl", "sft", or "unknown"
21
+ train_type: str # "rl", "sft", "prompt_learning", or "unknown"
22
22
 
23
23
 
24
24
  def _load_last_config() -> Path | None:
@@ -94,6 +94,17 @@ def _iter_candidate_paths() -> Iterable[Path]:
94
94
 
95
95
 
96
96
  def _infer_config_type(data: dict) -> str:
97
+ # 0) Check for prompt_learning section (highest priority)
98
+ pl_section = data.get("prompt_learning")
99
+ if isinstance(pl_section, dict):
100
+ algorithm = pl_section.get("algorithm", "").lower()
101
+ if algorithm in {"mipro", "gepa"}:
102
+ return "prompt_learning"
103
+ # Also check if top-level has prompt_learning indicators
104
+ algorithm = data.get("algorithm")
105
+ if isinstance(algorithm, str) and algorithm.lower() in {"mipro", "gepa"}:
106
+ return "prompt_learning"
107
+
97
108
  # 1) Strong signals from [algorithm]
98
109
  algo = data.get("algorithm")
99
110
  if isinstance(algo, dict):
@@ -152,7 +163,7 @@ def discover_configs(explicit: list[str], *, requested_type: str | None) -> list
152
163
  cfg_type = _infer_config_type(data)
153
164
  if cfg_type == "unknown":
154
165
  raise click.ClickException(
155
- f"Config {path} is missing algorithm.type/method metadata. Add type = 'rl' or 'sft'."
166
+ f"Config {path} is missing algorithm.type/method metadata. Add type = 'rl', 'sft', or 'prompt_learning'."
156
167
  )
157
168
  candidates.append(ConfigCandidate(path=path, train_type=cfg_type))
158
169
  seen.add(path)
@@ -1,5 +1,13 @@
1
- """Typed training config loaders for RL and SFT jobs."""
1
+ """Typed training config loaders for RL, SFT, and Prompt Learning jobs."""
2
2
 
3
+ from .prompt_learning import (
4
+ GEPAConfig,
5
+ MessagePatternConfig,
6
+ MIPROConfig,
7
+ PromptLearningConfig,
8
+ PromptLearningPolicyConfig,
9
+ PromptPatternConfig,
10
+ )
3
11
  from .rl import (
4
12
  EvaluationConfig,
5
13
  JudgeConfig,
@@ -28,14 +36,20 @@ __all__ = [
28
36
  "AlgorithmConfig",
29
37
  "ComputeConfig",
30
38
  "EvaluationConfig",
39
+ "GEPAConfig",
31
40
  "HyperparametersConfig",
32
41
  "HyperparametersParallelism",
33
42
  "JobConfig",
34
43
  "JudgeConfig",
35
44
  "JudgeOptionsConfig",
36
45
  "LoraConfig",
46
+ "MIPROConfig",
47
+ "MessagePatternConfig",
37
48
  "ModelConfig",
38
49
  "PolicyConfig",
50
+ "PromptLearningConfig",
51
+ "PromptLearningPolicyConfig",
52
+ "PromptPatternConfig",
39
53
  "RewardsConfig",
40
54
  "RLConfig",
41
55
  "RLServicesConfig",