synth_ai-0.2.16-py3-none-any.whl → synth_ai-0.2.19-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai has been flagged as potentially problematic; see the registry's advisory page for more details.
- examples/analyze_semantic_words.sh +2 -2
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +27 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +43 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +91 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/verilog_rl_lora.toml +80 -123
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
- examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
- examples/qwen_coder/configs/coder_lora_small.toml +1 -3
- examples/qwen_vl/README.md +10 -12
- examples/qwen_vl/SETUP_COMPLETE.md +7 -8
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
- examples/qwen_vl/collect_data_via_cli.md +76 -84
- examples/qwen_vl/collect_vision_traces.py +4 -4
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
- examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
- examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
- examples/qwen_vl/run_vision_comparison.sh +6 -7
- examples/rl/README.md +5 -5
- examples/rl/configs/rl_from_base_qwen.toml +26 -1
- examples/rl/configs/rl_from_base_qwen17.toml +6 -2
- examples/rl/task_app/README.md +1 -2
- examples/rl/task_app/math_single_step.py +2 -2
- examples/run_crafter_demo.sh +2 -2
- examples/sft/README.md +1 -1
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
- examples/swe/task_app/README.md +32 -2
- examples/swe/task_app/grpo_swe_mini.py +4 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
- examples/swe/task_app/hosted/inference/openai_client.py +4 -38
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/README.md +1 -1
- examples/task_apps/crafter/task_app/grpo_crafter.py +90 -5
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +372 -107
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +81 -12
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +82 -11
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/math/README.md +1 -2
- examples/task_apps/pokemon_red/README.md +3 -4
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
- examples/task_apps/pokemon_red/task_app.py +288 -39
- examples/task_apps/sokoban/README.md +2 -3
- examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
- examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +3 -2
- examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +1 -1
- examples/warming_up_to_rl/task_app/grpo_crafter.py +185 -5
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +156 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +37 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +6 -0
- synth_ai/api/train/builders.py +99 -4
- synth_ai/api/train/cli.py +516 -26
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +23 -2
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +61 -7
- synth_ai/api/train/configs/sft.py +6 -2
- synth_ai/api/train/configs/shared.py +59 -2
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/auth/credentials.py +119 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +94 -18
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +18 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1112 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +424 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +177 -0
- synth_ai/cli/commands/help/core.py +72 -0
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/__init__.py +64 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +200 -0
- synth_ai/cli/commands/train/judge_validation.py +305 -0
- synth_ai/cli/commands/train/validation.py +386 -0
- synth_ai/cli/demo.py +30 -158
- synth_ai/cli/deploy/__init__.py +43 -0
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +20 -265
- synth_ai/cli/status.py +7 -126
- synth_ai/cli/task_app_deploy.py +1 -10
- synth_ai/cli/task_app_modal_serve.py +4 -9
- synth_ai/cli/task_app_serve.py +4 -11
- synth_ai/cli/task_apps.py +51 -1480
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +1 -14
- synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/environments/examples/red/engine.py +33 -12
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/environment.py +26 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/http.py +12 -0
- synth_ai/judge_schemas.py +10 -10
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/learning/rl/client.py +3 -1
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +518 -0
- synth_ai/streaming/streamer.py +320 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +45 -9
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +40 -33
- synth_ai/utils/http.py +4 -1
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +285 -3
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/METADATA +109 -6
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/RECORD +291 -142
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
- synth_ai/cli/tui.py +0 -62
- synth_ai/tui/__init__.py +0 -5
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -911
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
synth_ai/api/train/cli.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
4
|
+
import contextlib
|
|
3
5
|
import importlib
|
|
6
|
+
import json
|
|
4
7
|
import os
|
|
5
8
|
import time
|
|
6
9
|
from collections.abc import Callable, Mapping
|
|
@@ -17,10 +20,18 @@ try:
|
|
|
17
20
|
except Exception as exc: # pragma: no cover - critical dependency
|
|
18
21
|
raise RuntimeError("Unable to load backend configuration helpers") from exc
|
|
19
22
|
|
|
20
|
-
from .
|
|
23
|
+
from synth_ai.streaming import (
|
|
24
|
+
CLIHandler,
|
|
25
|
+
JobStreamer,
|
|
26
|
+
LossCurveHandler,
|
|
27
|
+
StreamConfig,
|
|
28
|
+
StreamEndpoints,
|
|
29
|
+
StreamType,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
from .builders import build_prompt_learning_payload, build_rl_payload, build_sft_payload
|
|
21
33
|
from .config_finder import discover_configs, prompt_for_config
|
|
22
34
|
from .env_resolver import KeySpec, resolve_env
|
|
23
|
-
from .pollers import RLJobPoller, SFTJobPoller
|
|
24
35
|
from .task_app import check_task_app_health
|
|
25
36
|
from .utils import (
|
|
26
37
|
REPO_ROOT,
|
|
@@ -36,6 +47,45 @@ from .utils import (
|
|
|
36
47
|
validate_sft_jsonl,
|
|
37
48
|
)
|
|
38
49
|
|
|
50
|
+
# Constants for prompt learning event types
|
|
51
|
+
_PROMPT_LEARNING_EVENT_BEST_PROMPT = "prompt.learning.best.prompt"
|
|
52
|
+
_PROMPT_LEARNING_EVENT_FINAL_RESULTS = "prompt.learning.final.results"
|
|
53
|
+
_PROMPT_LEARNING_EVENT_VALIDATION_SCORED = "prompt.learning.validation.scored"
|
|
54
|
+
_PROMPT_LEARNING_EVENT_GEPA_COMPLETE = "prompt.learning.gepa.complete"
|
|
55
|
+
|
|
56
|
+
# Constants for formatting
|
|
57
|
+
_MAX_TEXT_REPLACEMENTS_DISPLAY = 3 # Max number of text replacements to show in output
|
|
58
|
+
_RESULTS_FILE_MAX_EVENTS = 10000 # Max events to fetch for results file generation
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _format_text_replacements(obj: dict[str, Any] | None, max_display: int = _MAX_TEXT_REPLACEMENTS_DISPLAY) -> list[str]:
|
|
62
|
+
"""Extract and format text replacements from a candidate object.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
obj: Candidate object dictionary containing text_replacements
|
|
66
|
+
max_display: Maximum number of replacements to display
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
List of formatted lines showing role and replacement text
|
|
70
|
+
"""
|
|
71
|
+
lines = []
|
|
72
|
+
if not obj or not isinstance(obj, dict):
|
|
73
|
+
return lines
|
|
74
|
+
|
|
75
|
+
text_replacements = obj.get("text_replacements", [])
|
|
76
|
+
if not text_replacements or not isinstance(text_replacements, list):
|
|
77
|
+
return lines
|
|
78
|
+
|
|
79
|
+
for replacement in text_replacements[:max_display]:
|
|
80
|
+
if isinstance(replacement, dict):
|
|
81
|
+
new_text = replacement.get("new_text", "")
|
|
82
|
+
role = replacement.get("apply_to_role", "system")
|
|
83
|
+
if new_text:
|
|
84
|
+
lines.append(f" [{role.upper()}]: {new_text}")
|
|
85
|
+
lines.append("")
|
|
86
|
+
|
|
87
|
+
return lines
|
|
88
|
+
|
|
39
89
|
|
|
40
90
|
def _discover_dataset_candidates(
|
|
41
91
|
config_path: Path, limit: int = 50, timeout: float = 10.0
|
|
@@ -135,6 +185,66 @@ def _default_backend() -> str:
|
|
|
135
185
|
return f"{base}/api" if not base.endswith("/api") else base
|
|
136
186
|
|
|
137
187
|
|
|
188
|
+
_DEFAULT_SFT_HIDDEN_EVENTS = {
|
|
189
|
+
"sft.created",
|
|
190
|
+
"sft.pricing.check.requested",
|
|
191
|
+
"sft.pricing.check.allowed",
|
|
192
|
+
"sft.stage",
|
|
193
|
+
"snapshot.fetch",
|
|
194
|
+
"hatchet.preflight",
|
|
195
|
+
"hatchet.submission.attempt",
|
|
196
|
+
"hatchet.submission.result",
|
|
197
|
+
"sft.running",
|
|
198
|
+
"sft.status",
|
|
199
|
+
"sft.worker.alive",
|
|
200
|
+
"sft.dispatch.selected",
|
|
201
|
+
"sft.config.prepared",
|
|
202
|
+
"sft.strategy.selected",
|
|
203
|
+
"sft.training.args",
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
_DEFAULT_RL_HIDDEN_SUBSTRINGS = {"modal", "hatchet"}
|
|
207
|
+
|
|
208
|
+
_DEFAULT_PROMPT_LEARNING_HIDDEN_EVENTS = {
|
|
209
|
+
"prompt.learning.policy.tokens",
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _build_stream_components(
    stream_format: str,
    *,
    hidden_event_types: set[str] | None = None,
    hidden_event_substrings: set[str] | None = None,
) -> tuple[StreamConfig, list]:
    """Return stream configuration and handlers for the requested format.

    ``"chart"`` subscribes to status, events and the ``train.loss`` metric and
    renders them with a ``LossCurveHandler``; the hidden-event filters are not
    used in that mode. Any other value uses ``StreamConfig.default()`` with a
    ``CLIHandler`` that receives the hidden-event filters (``None`` becomes an
    empty set).
    """
    if stream_format != "chart":
        default_config = StreamConfig.default()
        cli_handler = CLIHandler(
            hidden_event_types=hidden_event_types or set(),
            hidden_event_substrings=hidden_event_substrings or set(),
        )
        return default_config, [cli_handler]

    chart_config = StreamConfig(
        enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
        event_types={
            # SFT progress / lifecycle
            "sft.progress",
            "sft.training.started",
            "sft.training.finish",
            "sft.validation.summary",
            # RL training lifecycle
            "rl.train.step",
            "rl.train.started",
            "rl.train.completed",
            # Terminal workflow events
            "workflow.completed",
            "workflow.failed",
        },
        metric_names={"train.loss"},
    )
    return chart_config, [LossCurveHandler()]
|
|
246
|
+
|
|
247
|
+
|
|
138
248
|
@click.command("train")
|
|
139
249
|
@click.option(
|
|
140
250
|
"--config",
|
|
@@ -143,7 +253,7 @@ def _default_backend() -> str:
|
|
|
143
253
|
type=click.Path(),
|
|
144
254
|
help="Path to training TOML (repeatable)",
|
|
145
255
|
)
|
|
146
|
-
@click.option("--type", "train_type", type=click.Choice(["auto", "rl", "sft"]), default="auto")
|
|
256
|
+
@click.option("--type", "train_type", type=click.Choice(["auto", "rl", "sft", "prompt_learning"]), default="auto")
|
|
147
257
|
@click.option(
|
|
148
258
|
"--env-file",
|
|
149
259
|
"env_files",
|
|
@@ -183,6 +293,13 @@ def _default_backend() -> str:
|
|
|
183
293
|
"--poll-timeout", default=3600.0, type=float, help="Maximum seconds to poll before timing out"
|
|
184
294
|
)
|
|
185
295
|
@click.option("--poll-interval", default=5.0, type=float, help="Seconds between poll attempts")
|
|
296
|
+
@click.option(
|
|
297
|
+
"--stream-format",
|
|
298
|
+
type=click.Choice(["cli", "chart"]),
|
|
299
|
+
default="cli",
|
|
300
|
+
show_default=True,
|
|
301
|
+
help="Streaming output style (cli = line updates, chart = live loss panel)",
|
|
302
|
+
)
|
|
186
303
|
@click.option(
|
|
187
304
|
"--examples",
|
|
188
305
|
"examples_limit",
|
|
@@ -204,9 +321,10 @@ def train_command(
|
|
|
204
321
|
poll: bool,
|
|
205
322
|
poll_timeout: float,
|
|
206
323
|
poll_interval: float,
|
|
324
|
+
stream_format: str,
|
|
207
325
|
examples_limit: int | None,
|
|
208
326
|
) -> None:
|
|
209
|
-
"""Interactive launcher for RL / SFT jobs."""
|
|
327
|
+
"""Interactive launcher for RL / SFT / Prompt Learning jobs."""
|
|
210
328
|
|
|
211
329
|
candidates = discover_configs(
|
|
212
330
|
list(config_paths), requested_type=train_type if train_type != "auto" else None
|
|
@@ -218,16 +336,16 @@ def train_command(
|
|
|
218
336
|
)
|
|
219
337
|
|
|
220
338
|
effective_type = train_type if train_type != "auto" else selection.train_type
|
|
221
|
-
if effective_type not in {"rl", "sft"}:
|
|
339
|
+
if effective_type not in {"rl", "sft", "prompt_learning"}:
|
|
222
340
|
effective_type = click.prompt(
|
|
223
|
-
"Detected config type is ambiguous. Enter type", type=click.Choice(["rl", "sft"])
|
|
341
|
+
"Detected config type is ambiguous. Enter type", type=click.Choice(["rl", "sft", "prompt_learning"])
|
|
224
342
|
)
|
|
225
343
|
|
|
226
344
|
cfg_path = selection.path
|
|
227
345
|
click.echo(f"Using config: {cfg_path} ({effective_type})")
|
|
228
346
|
|
|
229
347
|
required_keys: list[KeySpec] = []
|
|
230
|
-
if effective_type == "rl":
|
|
348
|
+
if effective_type == "rl" or effective_type == "prompt_learning":
|
|
231
349
|
required_keys.append(KeySpec("SYNTH_API_KEY", "Synth API key for backend"))
|
|
232
350
|
required_keys.append(
|
|
233
351
|
KeySpec(
|
|
@@ -302,6 +420,20 @@ def train_command(
|
|
|
302
420
|
poll=poll,
|
|
303
421
|
poll_timeout=poll_timeout,
|
|
304
422
|
poll_interval=poll_interval,
|
|
423
|
+
stream_format=stream_format,
|
|
424
|
+
)
|
|
425
|
+
elif effective_type == "prompt_learning":
|
|
426
|
+
handle_prompt_learning(
|
|
427
|
+
cfg_path=cfg_path,
|
|
428
|
+
backend_base=backend_base,
|
|
429
|
+
synth_key=synth_key,
|
|
430
|
+
task_url_override=task_url,
|
|
431
|
+
allow_experimental=allow_experimental,
|
|
432
|
+
dry_run=dry_run,
|
|
433
|
+
poll=poll,
|
|
434
|
+
poll_timeout=poll_timeout,
|
|
435
|
+
poll_interval=poll_interval,
|
|
436
|
+
stream_format=stream_format,
|
|
305
437
|
)
|
|
306
438
|
else:
|
|
307
439
|
dataset_override_path = Path(dataset_path).expanduser().resolve() if dataset_path else None
|
|
@@ -315,13 +447,22 @@ def train_command(
|
|
|
315
447
|
poll=poll,
|
|
316
448
|
poll_timeout=poll_timeout,
|
|
317
449
|
poll_interval=poll_interval,
|
|
450
|
+
stream_format=stream_format,
|
|
318
451
|
examples_limit=examples_limit,
|
|
319
452
|
)
|
|
320
453
|
|
|
321
454
|
|
|
322
455
|
def _wait_for_training_file(
|
|
323
|
-
backend_base: str, api_key: str, file_id: str, *, timeout: float =
|
|
456
|
+
backend_base: str, api_key: str, file_id: str, *, timeout: float = 10.0
|
|
324
457
|
) -> None:
|
|
458
|
+
"""Wait for training file to be visible after upload.
|
|
459
|
+
|
|
460
|
+
Reduced from 120s to 10s because:
|
|
461
|
+
- POST response already confirms file is uploaded
|
|
462
|
+
- Backend now forces read-your-writes consistency
|
|
463
|
+
- By job creation time, replica lag has resolved
|
|
464
|
+
- Quick sanity check only, not critical path
|
|
465
|
+
"""
|
|
325
466
|
url = f"{backend_base.rstrip('/')}/files/{file_id}"
|
|
326
467
|
headers = {"Authorization": f"Bearer {api_key}"}
|
|
327
468
|
elapsed = 0.0
|
|
@@ -332,7 +473,7 @@ def _wait_for_training_file(
|
|
|
332
473
|
if resp.status_code == 200:
|
|
333
474
|
try:
|
|
334
475
|
data = resp.json()
|
|
335
|
-
except
|
|
476
|
+
except json.JSONDecodeError:
|
|
336
477
|
data = {}
|
|
337
478
|
status = str(
|
|
338
479
|
data.get("status") or data.get("state") or data.get("storage_state") or "ready"
|
|
@@ -357,7 +498,7 @@ def _wait_for_training_file(
|
|
|
357
498
|
# Auth errors won't resolve by polling - fail immediately
|
|
358
499
|
try:
|
|
359
500
|
error_body = resp.json()
|
|
360
|
-
except
|
|
501
|
+
except json.JSONDecodeError:
|
|
361
502
|
error_body = resp.text[:400]
|
|
362
503
|
click.echo("\n[ERROR] Authentication failed when checking training file:")
|
|
363
504
|
click.echo(f" URL: {url}")
|
|
@@ -372,7 +513,7 @@ def _wait_for_training_file(
|
|
|
372
513
|
# Other errors - show details but keep polling
|
|
373
514
|
try:
|
|
374
515
|
error_body = resp.json()
|
|
375
|
-
except
|
|
516
|
+
except json.JSONDecodeError:
|
|
376
517
|
error_body = resp.text[:400]
|
|
377
518
|
click.echo(f"[WARN] Unexpected response checking file {file_id}:")
|
|
378
519
|
click.echo(f" URL: {url}")
|
|
@@ -400,6 +541,7 @@ def handle_rl(
|
|
|
400
541
|
poll: bool,
|
|
401
542
|
poll_timeout: float,
|
|
402
543
|
poll_interval: float,
|
|
544
|
+
stream_format: str,
|
|
403
545
|
) -> None:
|
|
404
546
|
overrides: dict[str, Any] = {
|
|
405
547
|
"backend": backend_base,
|
|
@@ -423,7 +565,7 @@ def handle_rl(
|
|
|
423
565
|
)
|
|
424
566
|
try:
|
|
425
567
|
parsed_json = vresp.json()
|
|
426
|
-
except
|
|
568
|
+
except json.JSONDecodeError:
|
|
427
569
|
parsed_json = None
|
|
428
570
|
|
|
429
571
|
if isinstance(parsed_json, Mapping):
|
|
@@ -458,8 +600,9 @@ def handle_rl(
|
|
|
458
600
|
)
|
|
459
601
|
statuses = [attempt.get("status") for attempt in attempts]
|
|
460
602
|
click.echo(f"Verification OK (candidates={cands}, statuses={statuses})")
|
|
461
|
-
except
|
|
462
|
-
|
|
603
|
+
except (KeyError, ValueError, AttributeError):
|
|
604
|
+
# Parsing verification summary failed, but verification itself succeeded
|
|
605
|
+
click.echo("Verification OK")
|
|
463
606
|
|
|
464
607
|
env_key = os.environ.get("ENVIRONMENT_API_KEY")
|
|
465
608
|
if not env_key:
|
|
@@ -484,7 +627,8 @@ def handle_rl(
|
|
|
484
627
|
resp = http_post(create_url, headers=headers, json_body=build.payload)
|
|
485
628
|
try:
|
|
486
629
|
js = resp.json()
|
|
487
|
-
except
|
|
630
|
+
except json.JSONDecodeError as e:
|
|
631
|
+
click.echo(f"⚠️ Failed to parse JSON response: {e}")
|
|
488
632
|
js = {"status": resp.status_code, "text": resp.text[:400]}
|
|
489
633
|
click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
|
|
490
634
|
if resp.status_code not in (200, 201):
|
|
@@ -497,10 +641,41 @@ def handle_rl(
|
|
|
497
641
|
click.echo(f"Created job {job_id} (polling disabled)")
|
|
498
642
|
return
|
|
499
643
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
644
|
+
click.echo("\n=== Streaming Job Progress ===")
|
|
645
|
+
|
|
646
|
+
# Enable metrics for prompt learning
|
|
647
|
+
if stream_format == "chart":
|
|
648
|
+
config = StreamConfig(
|
|
649
|
+
enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
|
|
650
|
+
event_types={
|
|
651
|
+
"prompt.learning.progress",
|
|
652
|
+
"prompt.learning.gepa.start",
|
|
653
|
+
"prompt.learning.gepa.complete",
|
|
654
|
+
},
|
|
655
|
+
metric_names={"gepa.transformation.mean_score"},
|
|
656
|
+
)
|
|
657
|
+
handlers = [LossCurveHandler()]
|
|
658
|
+
click.echo("Using live chart (metric=gepa.transformation.mean_score)")
|
|
659
|
+
else:
|
|
660
|
+
config = StreamConfig(
|
|
661
|
+
enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
|
|
662
|
+
metric_names={"gepa.transformation.mean_score"},
|
|
663
|
+
)
|
|
664
|
+
handlers = [CLIHandler(hidden_event_substrings=_DEFAULT_RL_HIDDEN_SUBSTRINGS)]
|
|
665
|
+
|
|
666
|
+
streamer = JobStreamer(
|
|
667
|
+
base_url=backend_base,
|
|
668
|
+
api_key=synth_key,
|
|
669
|
+
job_id=job_id,
|
|
670
|
+
endpoints=StreamEndpoints.rl(job_id),
|
|
671
|
+
config=config,
|
|
672
|
+
handlers=handlers,
|
|
673
|
+
interval_seconds=poll_interval,
|
|
674
|
+
timeout_seconds=poll_timeout,
|
|
675
|
+
)
|
|
676
|
+
final_status = asyncio.run(streamer.stream_until_terminal())
|
|
677
|
+
click.echo(f"Final status: {final_status.get('status', 'unknown')}")
|
|
678
|
+
click.echo(preview_json(final_status, limit=600))
|
|
504
679
|
|
|
505
680
|
|
|
506
681
|
def handle_sft(
|
|
@@ -514,6 +689,7 @@ def handle_sft(
|
|
|
514
689
|
poll: bool,
|
|
515
690
|
poll_timeout: float,
|
|
516
691
|
poll_interval: float,
|
|
692
|
+
stream_format: str,
|
|
517
693
|
examples_limit: int | None,
|
|
518
694
|
) -> None:
|
|
519
695
|
dataset_path = dataset_override
|
|
@@ -641,17 +817,331 @@ def handle_sft(
|
|
|
641
817
|
click.echo(f"Started job {job_id} (polling disabled)")
|
|
642
818
|
return
|
|
643
819
|
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
820
|
+
click.echo("\n=== Streaming Job Progress ===")
|
|
821
|
+
config, handlers = _build_stream_components(
|
|
822
|
+
stream_format, hidden_event_types=_DEFAULT_SFT_HIDDEN_EVENTS
|
|
823
|
+
)
|
|
824
|
+
if stream_format == "chart":
|
|
825
|
+
click.echo("Using live loss chart (metric=train.loss)")
|
|
826
|
+
streamer = JobStreamer(
|
|
827
|
+
base_url=backend_base,
|
|
828
|
+
api_key=synth_key,
|
|
829
|
+
job_id=job_id,
|
|
830
|
+
endpoints=StreamEndpoints.learning(job_id),
|
|
831
|
+
config=config,
|
|
832
|
+
handlers=handlers,
|
|
833
|
+
interval_seconds=poll_interval,
|
|
834
|
+
timeout_seconds=poll_timeout,
|
|
835
|
+
)
|
|
836
|
+
final_status = asyncio.run(streamer.stream_until_terminal())
|
|
837
|
+
status = final_status.get('status') if isinstance(final_status, dict) else 'unknown'
|
|
838
|
+
click.echo(f"Final status: {status}")
|
|
839
|
+
click.echo(preview_json(final_status, limit=600))
|
|
648
840
|
finally:
|
|
649
841
|
if limited_path is not None:
|
|
650
|
-
|
|
842
|
+
with contextlib.suppress(OSError):
|
|
651
843
|
limited_path.unlink(missing_ok=True)
|
|
844
|
+
# Clean up empty parent directory if possible
|
|
845
|
+
with contextlib.suppress(OSError):
|
|
652
846
|
limited_path.parent.rmdir()
|
|
653
|
-
|
|
654
|
-
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
def _prompt_learning_float(value: Any, default: float | None = None) -> float | None:
    """Return *value* converted to float, or *default* when it is missing or non-numeric."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _extract_prompt_learning_summary(events: list[Any]) -> dict[str, Any]:
    """Collect scores, the best prompt and candidate lists from prompt learning events.

    Malformed entries (non-dict events or non-dict ``data`` payloads) are skipped.
    Candidate collections that are not lists are normalized to empty lists.
    """
    best_score: Any = None
    best_prompt: Any = None
    baseline_score: Any = None
    attempted: Any = []
    optimized: Any = []

    for event in events:
        if not isinstance(event, dict):
            continue  # Skip malformed events
        event_type = event.get("type", "")
        event_data = event.get("data", {})
        if not isinstance(event_data, dict):
            event_data = {}  # Fallback to empty dict for safety

        if event_type == _PROMPT_LEARNING_EVENT_BEST_PROMPT:
            best_score = event_data.get("best_score")
            best_prompt = event_data.get("best_prompt")
        elif event_type == _PROMPT_LEARNING_EVENT_FINAL_RESULTS:
            attempted = event_data.get("attempted_candidates", [])
            optimized = event_data.get("optimized_candidates", [])
        elif event_type == _PROMPT_LEARNING_EVENT_VALIDATION_SCORED:
            # Baseline is flagged explicitly, or mentioned in the event message.
            is_baseline = bool(event_data.get("is_baseline", False))
            if not is_baseline:
                # `message` may be absent or null; treat both as empty text.
                message = event.get("message") or ""
                is_baseline = "baseline" in str(message).lower()
            if is_baseline:
                baseline_score = event_data.get("accuracy")
        elif event_type == _PROMPT_LEARNING_EVENT_GEPA_COMPLETE and best_score is None:
            # Only used as a fallback when no best-prompt event supplied a score.
            best_score = event_data.get("best_score")

    return {
        "best_score": best_score,
        "best_prompt": best_prompt,
        "baseline_score": baseline_score,
        "attempted_candidates": attempted if isinstance(attempted, list) else [],
        "optimized_candidates": optimized if isinstance(optimized, list) else [],
    }


def _format_prompt_learning_report(
    *, job_id: str, timestamp: str, summary: dict[str, Any]
) -> list[str]:
    """Render the prompt learning results report as a list of text lines.

    Non-numeric scores are omitted from the header, and shown as 0.0 for
    candidates, instead of raising. One malformed field therefore cannot
    discard the whole report.
    """
    rule = "=" * 80
    divider = "-" * 80
    lines: list[str] = [
        rule,
        "GEPA PROMPT LEARNING RESULTS",
        rule,
        f"Job ID: {job_id}",
        f"Timestamp: {timestamp}",
        "",
    ]

    baseline = _prompt_learning_float(summary["baseline_score"])
    best = _prompt_learning_float(summary["best_score"])
    if baseline is not None:
        lines.append(f"📊 Baseline Score: {baseline:.4f} ({baseline * 100:.1f}%)")
    if best is not None:
        lines.append(f"🏆 Best Score: {best:.4f} ({best * 100:.1f}%)")
    if baseline is not None and best is not None:
        # Relative improvement is undefined for a non-positive baseline.
        relative = f"{(best - baseline) / baseline * 100:+.1f}%" if baseline > 0 else "n/a"
        lines.append(
            f"📈 Improvement: {relative} relative "
            f"({(best - baseline) * 100:+.1f} pp absolute)"
        )
    lines.extend([rule, ""])

    # Best prompt, section by section.
    best_prompt = summary["best_prompt"]
    if isinstance(best_prompt, dict) and best_prompt:
        lines.extend(["🏆 BEST PROMPT", divider])
        sections = best_prompt.get("sections")
        for section in sections if isinstance(sections, list) else []:
            if not isinstance(section, dict):
                continue
            role = section.get("role") or "unknown"
            content = section.get("content")
            lines.append(f"\n[{str(role).upper()}]:")
            lines.append("" if content is None else str(content))
        lines.append("")

    # Optimized candidates (scores nested under `score`).
    optimized = summary["optimized_candidates"]
    if optimized:
        lines.extend([rule, f"✨ TOP OPTIMIZED CANDIDATES ({len(optimized)})", rule, ""])
        for idx, cand in enumerate(optimized, start=1):
            if not isinstance(cand, dict):
                continue
            score = cand.get("score")
            if not isinstance(score, dict):
                score = {}
            accuracy = _prompt_learning_float(score.get("accuracy"), 0.0)
            prompt_length = score.get("prompt_length", 0)
            payload_kind = cand.get("payload_kind", "unknown")
            # Prefer score.instance_scores; fall back to the candidate-level field.
            instance_scores = (
                score["instance_scores"] if "instance_scores" in score
                else cand.get("instance_scores")
            )
            n_eval = len(instance_scores) if isinstance(instance_scores, list) else 0

            lines.append(
                f"[{idx}] Accuracy: {accuracy:.4f} | Length: {prompt_length} "
                f"| Type: {payload_kind} | N: {n_eval}"
            )
            lines.append(divider)

            obj = cand.get("object")
            if isinstance(obj, dict) and obj and payload_kind == "transformation":
                # For transformations, text_replacements are nested in data
                lines.extend(_format_text_replacements(obj.get("data", {})))
            lines.append("")

    # Every proposal candidate that was attempted.
    attempted = summary["attempted_candidates"]
    if attempted:
        lines.extend([rule, f"💡 ALL PROPOSAL CANDIDATES ({len(attempted)})", rule, ""])
        for idx, cand in enumerate(attempted, start=1):
            if not isinstance(cand, dict):
                continue
            accuracy = _prompt_learning_float(cand.get("accuracy"), 0.0)
            tool_rate = _prompt_learning_float(cand.get("tool_call_rate"), 0.0)
            prompt_length = cand.get("prompt_length", 0)
            instance_scores = cand.get("instance_scores")
            n_eval = len(instance_scores) if isinstance(instance_scores, list) else 0

            lines.append(
                f"[{idx}] Accuracy: {accuracy:.4f} | Length: {prompt_length} "
                f"| Tool Rate: {tool_rate:.2f} | N: {n_eval}"
            )
            lines.append(divider)

            obj = cand.get("object")
            if isinstance(obj, dict) and obj:
                # For proposals, text_replacements are at top level of object
                lines.extend(_format_text_replacements(obj))
            lines.append("")

    lines.extend([rule, "END OF REPORT", rule])
    return lines


def _save_prompt_learning_results_locally(
    *,
    backend_base: str,
    api_key: str,
    job_id: str,
    config_path: Path,
) -> None:
    """Fetch events and generate results file locally after prompt learning completes.

    The events are read from
    ``{backend_base}/prompt-learning/online/jobs/{job_id}/events``. The report is
    written to ``<config dir>/results/gepa_results_<job_id>_<timestamp>.txt``.

    This is best-effort. Failures are reported via ``click.echo`` and never
    raised, so a completed job is not turned into a CLI error. Nothing is
    written when the job has no events or no candidates.
    """
    from datetime import datetime

    try:
        # Fetch all events
        url = (
            f"{backend_base}/prompt-learning/online/jobs/{job_id}/events"
            f"?limit={_RESULTS_FILE_MAX_EVENTS}"
        )
        headers = {"Authorization": f"Bearer {api_key}"}
        resp = http_get(url, headers=headers, timeout=30.0)

        if resp.status_code != 200:
            click.echo(f"⚠️ Could not fetch events to generate results file (status={resp.status_code})")
            return

        data = resp.json()
        # Validate response structure
        if not isinstance(data, dict):
            click.echo(f"⚠️ Unexpected response type: {type(data).__name__}")
            return

        events = data.get("events", [])
        if not isinstance(events, list):
            click.echo(f"⚠️ Events field is not a list: {type(events).__name__}")
            return
        if not events:
            return

        summary = _extract_prompt_learning_summary(events)
        if not (summary["attempted_candidates"] or summary["optimized_candidates"]):
            return

        # One clock reading so the report header and the file name agree.
        now = datetime.now()
        lines = _format_prompt_learning_report(
            job_id=job_id,
            timestamp=now.strftime("%Y-%m-%d %H:%M:%S"),
            summary=summary,
        )

        # Save next to the config file, under results/
        output_dir = config_path.parent / "results"
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"gepa_results_{job_id}_{now.strftime('%Y%m%d_%H%M%S')}.txt"
        output_file.write_text("\n".join(lines), encoding="utf-8")

        click.echo(f"\n📄 Results saved locally to: {output_file}")

    except OSError as e:  # includes PermissionError
        click.echo(f"⚠️ Could not save results file locally: {e}")
    except Exception as e:
        # Deliberate best-effort boundary: never fail the CLI after the job finished.
        click.echo(f"⚠️ Unexpected error saving results file: {e}")
|
|
1034
|
+
|
|
1035
|
+
|
|
1036
|
+
def handle_prompt_learning(
    *,
    cfg_path: Path,
    backend_base: str,
    synth_key: str,
    task_url_override: str | None,
    allow_experimental: bool | None,
    dry_run: bool,
    poll: bool,
    poll_timeout: float,
    poll_interval: float,
    stream_format: str,
) -> None:
    """Handle prompt learning job creation (MIPRO or GEPA).

    Steps:
    1. Build the payload from the TOML config at ``cfg_path``.
    2. Require ``ENVIRONMENT_API_KEY`` and health-check the task app.
    3. Print a payload preview and POST it to
       ``{backend_base}/prompt-learning/online/jobs``.
    4. When ``poll`` is set, stream progress until a terminal status.
    5. Save a local results report.

    When ``dry_run`` is set, the payload is previewed but never submitted.

    Raises:
        click.ClickException: missing ENVIRONMENT_API_KEY, failing health
            check, non-2xx creation response, or a response without a job id.
    """
    import os

    if task_url_override:
        # Prompt learning always takes the task URL from the TOML config.
        # Tell the user instead of silently dropping their override.
        click.echo(
            f"⚠️ Ignoring task URL override {task_url_override!r}; "
            "prompt learning uses the task URL from the TOML config"
        )

    overrides: dict[str, Any] = {"backend": backend_base}
    build = build_prompt_learning_payload(
        config_path=cfg_path,
        task_url=None,  # Force using TOML only
        overrides=overrides,
        allow_experimental=allow_experimental,
    )

    env_key = os.environ.get("ENVIRONMENT_API_KEY")
    if not env_key:
        raise click.ClickException("ENVIRONMENT_API_KEY required for prompt learning flow")

    click.echo("Performing task app health check…")
    health = check_task_app_health(build.task_url, env_key)
    if not health.ok:
        click.echo(f"Task app health check failed: {health.detail}")
        raise click.ClickException("Aborting due to failing health check")
    click.echo("Task app healthy")

    create_url = f"{backend_base}/prompt-learning/online/jobs"
    headers = {"Authorization": f"Bearer {synth_key}", "Content-Type": "application/json"}

    click.echo(f"POST {create_url}")
    click.echo("Payload preview:\n" + preview_json(build.payload, limit=800))

    if dry_run:
        # Validation and preview only: do not create a real job.
        click.echo("Dry run: payload not submitted")
        return

    resp = http_post(create_url, headers=headers, json_body=build.payload)
    try:
        js = resp.json()
    except ValueError as e:  # json.JSONDecodeError (and requests' variant) subclass ValueError
        click.echo(f"⚠️ Failed to parse JSON response: {e}")
        js = {"status": resp.status_code, "text": resp.text[:400]}
    click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
    if resp.status_code not in (200, 201):
        raise click.ClickException("Job creation failed")
    job_id = (js.get("job_id") or js.get("id")) if isinstance(js, dict) else None
    if not job_id:
        raise click.ClickException("Response missing job id")

    if not poll:
        click.echo(f"Created job {job_id} (polling disabled)")
        return

    click.echo("\n=== Streaming Job Progress ===")

    # Prompt learning reports progress through this metric, so the metrics
    # stream is enabled in both output modes.
    enabled_streams = {StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS}
    metric_names = {"gepa.transformation.mean_score"}
    if stream_format == "chart":
        config = StreamConfig(
            enabled_streams=enabled_streams,
            event_types={
                "prompt.learning.progress",
                "prompt.learning.gepa.start",
                "prompt.learning.gepa.complete",
            },
            metric_names=metric_names,
        )
        handlers = [LossCurveHandler()]
        click.echo("Using live loss chart (metric=gepa.transformation.mean_score)")
    else:
        config = StreamConfig(enabled_streams=enabled_streams, metric_names=metric_names)
        handlers = [
            CLIHandler(
                hidden_event_types=_DEFAULT_PROMPT_LEARNING_HIDDEN_EVENTS,
                hidden_event_substrings=_DEFAULT_RL_HIDDEN_SUBSTRINGS,
            )
        ]

    streamer = JobStreamer(
        base_url=backend_base,
        api_key=synth_key,
        job_id=job_id,
        endpoints=StreamEndpoints.prompt_learning(job_id),
        config=config,
        handlers=handlers,
        interval_seconds=poll_interval,
        timeout_seconds=poll_timeout,
    )
    final_status = asyncio.run(streamer.stream_until_terminal())
    # Same guard as handle_sft: the streamer's result is not assumed to be a dict.
    status = final_status.get("status", "unknown") if isinstance(final_status, dict) else "unknown"
    click.echo(f"Final status: {status}")
    click.echo(preview_json(final_status, limit=600))

    # Save results file locally
    _save_prompt_learning_results_locally(
        backend_base=backend_base,
        api_key=synth_key,
        job_id=job_id,
        config_path=cfg_path,
    )
|
|
655
1145
|
|
|
656
1146
|
|
|
657
1147
|
def register(cli: click.Group) -> None:
|