synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
- examples/rl/configs/rl_from_base_qwen17.toml +1 -0
- examples/swe/task_app/hosted/inference/openai_client.py +0 -34
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/task_app.py +254 -36
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
- synth_ai/api/train/builders.py +90 -1
- synth_ai/api/train/cli.py +396 -21
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +15 -1
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +29 -0
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +85 -17
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +1 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/eval/core.py +13 -10
- synth_ai/cli/commands/filter/core.py +53 -17
- synth_ai/cli/commands/help/core.py +0 -1
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/train/judge_schemas.py +1 -0
- synth_ai/cli/commands/train/judge_validation.py +1 -0
- synth_ai/cli/commands/train/validation.py +0 -57
- synth_ai/cli/demo.py +35 -3
- synth_ai/cli/deploy/__init__.py +40 -25
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/task_app_deploy.py +1 -1
- synth_ai/cli/task_apps.py +53 -53
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/judge_schemas.py +1 -0
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/handlers.py +53 -4
- synth_ai/streaming/streamer.py +19 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +44 -8
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +17 -17
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +283 -1
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
- synth_ai/cli/commands/deploy/__init__.py +0 -23
- synth_ai/cli/commands/deploy/core.py +0 -614
- synth_ai/cli/commands/deploy/errors.py +0 -72
- synth_ai/cli/commands/deploy/validation.py +0 -11
- synth_ai/cli/deploy/core.py +0 -5
- synth_ai/cli/deploy/errors.py +0 -23
- synth_ai/cli/deploy/validation.py +0 -5
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
synth_ai/api/train/cli.py
CHANGED
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
 import asyncio
+import contextlib
 import importlib
+import json
 import os
 import time
 from collections.abc import Callable, Mapping
@@ -27,7 +29,7 @@ from synth_ai.streaming import (
     StreamType,
 )
 
-from .builders import build_rl_payload, build_sft_payload
+from .builders import build_prompt_learning_payload, build_rl_payload, build_sft_payload
 from .config_finder import discover_configs, prompt_for_config
 from .env_resolver import KeySpec, resolve_env
 from .task_app import check_task_app_health
@@ -45,6 +47,45 @@ from .utils import (
     validate_sft_jsonl,
 )
 
+# Constants for prompt learning event types
+_PROMPT_LEARNING_EVENT_BEST_PROMPT = "prompt.learning.best.prompt"
+_PROMPT_LEARNING_EVENT_FINAL_RESULTS = "prompt.learning.final.results"
+_PROMPT_LEARNING_EVENT_VALIDATION_SCORED = "prompt.learning.validation.scored"
+_PROMPT_LEARNING_EVENT_GEPA_COMPLETE = "prompt.learning.gepa.complete"
+
+# Constants for formatting
+_MAX_TEXT_REPLACEMENTS_DISPLAY = 3  # Max number of text replacements to show in output
+_RESULTS_FILE_MAX_EVENTS = 10000  # Max events to fetch for results file generation
+
+
+def _format_text_replacements(obj: dict[str, Any] | None, max_display: int = _MAX_TEXT_REPLACEMENTS_DISPLAY) -> list[str]:
+    """Extract and format text replacements from a candidate object.
+
+    Args:
+        obj: Candidate object dictionary containing text_replacements
+        max_display: Maximum number of replacements to display
+
+    Returns:
+        List of formatted lines showing role and replacement text
+    """
+    lines = []
+    if not obj or not isinstance(obj, dict):
+        return lines
+
+    text_replacements = obj.get("text_replacements", [])
+    if not text_replacements or not isinstance(text_replacements, list):
+        return lines
+
+    for replacement in text_replacements[:max_display]:
+        if isinstance(replacement, dict):
+            new_text = replacement.get("new_text", "")
+            role = replacement.get("apply_to_role", "system")
+            if new_text:
+                lines.append(f" [{role.upper()}]: {new_text}")
+                lines.append("")
+
+    return lines
+
 
 def _discover_dataset_candidates(
     config_path: Path, limit: int = 50, timeout: float = 10.0
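To make the new helper's behavior concrete, here is an illustrative call; the candidate payload below is invented for demonstration and only `_format_text_replacements` itself comes from the hunk above:

    from synth_ai.api.train.cli import _format_text_replacements

    candidate = {
        "text_replacements": [
            {"apply_to_role": "system", "new_text": "You are a careful intent classifier."},
            {"apply_to_role": "user", "new_text": "Reply with exactly one label."},
        ]
    }

    for line in _format_text_replacements(candidate):
        print(line)
    # Prints " [ROLE]: text" for each replacement, followed by a blank line,
    # capped at _MAX_TEXT_REPLACEMENTS_DISPLAY (3) entries.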
@@ -164,6 +205,10 @@ _DEFAULT_SFT_HIDDEN_EVENTS = {
 
 _DEFAULT_RL_HIDDEN_SUBSTRINGS = {"modal", "hatchet"}
 
+_DEFAULT_PROMPT_LEARNING_HIDDEN_EVENTS = {
+    "prompt.learning.policy.tokens",
+}
+
 
 def _build_stream_components(
     stream_format: str,
@@ -208,7 +253,7 @@ def _build_stream_components(
     type=click.Path(),
     help="Path to training TOML (repeatable)",
 )
-@click.option("--type", "train_type", type=click.Choice(["auto", "rl", "sft"]), default="auto")
+@click.option("--type", "train_type", type=click.Choice(["auto", "rl", "sft", "prompt_learning"]), default="auto")
 @click.option(
     "--env-file",
     "env_files",
@@ -279,7 +324,7 @@ def train_command(
     stream_format: str,
     examples_limit: int | None,
 ) -> None:
-    """Interactive launcher for RL / SFT jobs."""
+    """Interactive launcher for RL / SFT / Prompt Learning jobs."""
 
     candidates = discover_configs(
         list(config_paths), requested_type=train_type if train_type != "auto" else None
@@ -291,16 +336,16 @@ def train_command(
     )
 
     effective_type = train_type if train_type != "auto" else selection.train_type
-    if effective_type not in {"rl", "sft"}:
+    if effective_type not in {"rl", "sft", "prompt_learning"}:
         effective_type = click.prompt(
-            "Detected config type is ambiguous. Enter type", type=click.Choice(["rl", "sft"])
+            "Detected config type is ambiguous. Enter type", type=click.Choice(["rl", "sft", "prompt_learning"])
         )
 
     cfg_path = selection.path
     click.echo(f"Using config: {cfg_path} ({effective_type})")
 
     required_keys: list[KeySpec] = []
-    if effective_type == "rl":
+    if effective_type == "rl" or effective_type == "prompt_learning":
         required_keys.append(KeySpec("SYNTH_API_KEY", "Synth API key for backend"))
         required_keys.append(
             KeySpec(
@@ -377,6 +422,19 @@ def train_command(
             poll_interval=poll_interval,
             stream_format=stream_format,
         )
+    elif effective_type == "prompt_learning":
+        handle_prompt_learning(
+            cfg_path=cfg_path,
+            backend_base=backend_base,
+            synth_key=synth_key,
+            task_url_override=task_url,
+            allow_experimental=allow_experimental,
+            dry_run=dry_run,
+            poll=poll,
+            poll_timeout=poll_timeout,
+            poll_interval=poll_interval,
+            stream_format=stream_format,
+        )
     else:
         dataset_override_path = Path(dataset_path).expanduser().resolve() if dataset_path else None
         handle_sft(
@@ -415,7 +473,7 @@ def _wait_for_training_file(
         if resp.status_code == 200:
             try:
                 data = resp.json()
-            except
+            except json.JSONDecodeError:
                 data = {}
             status = str(
                 data.get("status") or data.get("state") or data.get("storage_state") or "ready"
@@ -440,7 +498,7 @@ def _wait_for_training_file(
         # Auth errors won't resolve by polling - fail immediately
         try:
             error_body = resp.json()
-        except
+        except json.JSONDecodeError:
            error_body = resp.text[:400]
         click.echo("\n[ERROR] Authentication failed when checking training file:")
         click.echo(f"  URL: {url}")
@@ -455,7 +513,7 @@ def _wait_for_training_file(
         # Other errors - show details but keep polling
         try:
             error_body = resp.json()
-        except
+        except json.JSONDecodeError:
            error_body = resp.text[:400]
         click.echo(f"[WARN] Unexpected response checking file {file_id}:")
         click.echo(f"  URL: {url}")
@@ -507,7 +565,7 @@ def handle_rl(
     )
     try:
         parsed_json = vresp.json()
-    except
+    except json.JSONDecodeError:
         parsed_json = None
 
     if isinstance(parsed_json, Mapping):
@@ -542,8 +600,9 @@ def handle_rl(
         )
         statuses = [attempt.get("status") for attempt in attempts]
         click.echo(f"Verification OK (candidates={cands}, statuses={statuses})")
-    except
-
+    except (KeyError, ValueError, AttributeError):
+        # Parsing verification summary failed, but verification itself succeeded
+        click.echo("Verification OK")
 
     env_key = os.environ.get("ENVIRONMENT_API_KEY")
     if not env_key:
@@ -568,7 +627,8 @@ def handle_rl(
     resp = http_post(create_url, headers=headers, json_body=build.payload)
     try:
         js = resp.json()
-    except
+    except json.JSONDecodeError as e:
+        click.echo(f"⚠️ Failed to parse JSON response: {e}")
         js = {"status": resp.status_code, "text": resp.text[:400]}
     click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
     if resp.status_code not in (200, 201):
@@ -582,11 +642,27 @@ def handle_rl(
         return
 
     click.echo("\n=== Streaming Job Progress ===")
-
-
-    )
+
+    # Enable metrics for prompt learning
     if stream_format == "chart":
-
+        config = StreamConfig(
+            enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
+            event_types={
+                "prompt.learning.progress",
+                "prompt.learning.gepa.start",
+                "prompt.learning.gepa.complete",
+            },
+            metric_names={"gepa.transformation.mean_score"},
+        )
+        handlers = [LossCurveHandler()]
+        click.echo("Using live chart (metric=gepa.transformation.mean_score)")
+    else:
+        config = StreamConfig(
+            enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
+            metric_names={"gepa.transformation.mean_score"},
+        )
+        handlers = [CLIHandler(hidden_event_substrings=_DEFAULT_RL_HIDDEN_SUBSTRINGS)]
+
     streamer = JobStreamer(
         base_url=backend_base,
         api_key=synth_key,
@@ -758,15 +834,314 @@ def handle_sft(
             timeout_seconds=poll_timeout,
         )
         final_status = asyncio.run(streamer.stream_until_terminal())
-
+        status = final_status.get('status') if isinstance(final_status, dict) else 'unknown'
+        click.echo(f"Final status: {status}")
         click.echo(preview_json(final_status, limit=600))
     finally:
         if limited_path is not None:
-
+            with contextlib.suppress(OSError):
                 limited_path.unlink(missing_ok=True)
+            # Clean up empty parent directory if possible
+            with contextlib.suppress(OSError):
                 limited_path.parent.rmdir()
-
-
+
+
+def _save_prompt_learning_results_locally(
+    *,
+    backend_base: str,
+    api_key: str,
+    job_id: str,
+    config_path: Path,
+) -> None:
+    """Fetch events and generate results file locally after prompt learning completes."""
+    from datetime import datetime
+
+    try:
+        # Fetch all events
+        url = f"{backend_base}/prompt-learning/online/jobs/{job_id}/events?limit={_RESULTS_FILE_MAX_EVENTS}"
+        headers = {"Authorization": f"Bearer {api_key}"}
+        resp = http_get(url, headers=headers, timeout=30.0)
+
+        if resp.status_code != 200:
+            click.echo(f"⚠️ Could not fetch events to generate results file (status={resp.status_code})")
+            return
+
+        data = resp.json()
+        # Validate response structure
+        if not isinstance(data, dict):
+            click.echo(f"⚠️ Unexpected response type: {type(data).__name__}")
+            return
+
+        events = data.get("events", [])
+        if not isinstance(events, list):
+            click.echo(f"⚠️ Events field is not a list: {type(events).__name__}")
+            return
+
+        if not events:
+            return
+
+        # Extract key data from events
+        best_score = None
+        best_prompt = None
+        baseline_score = None
+        attempted_candidates = []
+        optimized_candidates = []
+
+        for event in events:
+            if not isinstance(event, dict):
+                continue  # Skip malformed events
+
+            event_type = event.get("type", "")
+            event_data = event.get("data", {})
+            if not isinstance(event_data, dict):
+                event_data = {}  # Fallback to empty dict for safety
+
+            if event_type == _PROMPT_LEARNING_EVENT_BEST_PROMPT:
+                best_score = event_data.get("best_score")
+                best_prompt = event_data.get("best_prompt")
+            elif event_type == _PROMPT_LEARNING_EVENT_FINAL_RESULTS:
+                attempted_candidates = event_data.get("attempted_candidates", [])
+                optimized_candidates = event_data.get("optimized_candidates", [])
+            elif event_type == _PROMPT_LEARNING_EVENT_VALIDATION_SCORED:
+                # Check if this is the baseline by checking for is_baseline flag or baseline in message
+                is_baseline = event_data.get("is_baseline", False)
+                if not is_baseline:
+                    msg = event.get("message", "")
+                    is_baseline = "baseline" in msg.lower()
+                if is_baseline:
+                    baseline_score = event_data.get("accuracy")
+            elif event_type == _PROMPT_LEARNING_EVENT_GEPA_COMPLETE and best_score is None:
+                best_score = event_data.get("best_score")
+
+        if not (attempted_candidates or optimized_candidates):
+            return
+
+        # Generate formatted report
+        lines = []
+        lines.append("=" * 80)
+        lines.append("GEPA PROMPT LEARNING RESULTS")
+        lines.append("=" * 80)
+        lines.append(f"Job ID: {job_id}")
+        lines.append(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+        lines.append("")
+        if baseline_score is not None:
+            lines.append(f"📊 Baseline Score: {baseline_score:.4f} ({baseline_score*100:.1f}%)")
+        if best_score is not None:
+            lines.append(f"🏆 Best Score: {best_score:.4f} ({best_score*100:.1f}%)")
+        if baseline_score is not None and best_score is not None:
+            improvement = ((best_score - baseline_score) / baseline_score) * 100 if baseline_score > 0 else 0
+            lines.append(f"📈 Improvement: {improvement:+.1f}% relative ({(best_score - baseline_score)*100:+.1f} pp absolute)")
+        lines.append("=" * 80)
+        lines.append("")
+
+        # Add best prompt if available
+        if best_prompt and isinstance(best_prompt, dict):
+            lines.append("🏆 BEST PROMPT")
+            lines.append("-" * 80)
+            sections = best_prompt.get("sections", [])
+            if not isinstance(sections, list):
+                sections = []
+            for sec in sections:
+                if not isinstance(sec, dict):
+                    continue
+                role = sec.get("role", "unknown")
+                content = sec.get("content", "")
+                lines.append(f"\n[{role.upper()}]:")
+                lines.append(content)
+            lines.append("")
+
+        # Add optimized candidates
+        if optimized_candidates and isinstance(optimized_candidates, list):
+            lines.append("=" * 80)
+            lines.append(f"✨ TOP OPTIMIZED CANDIDATES ({len(optimized_candidates)})")
+            lines.append("=" * 80)
+            lines.append("")
+
+            for idx, cand in enumerate(optimized_candidates):
+                if not isinstance(cand, dict):
+                    continue
+                candidate_score = cand.get("score") or {}
+                accuracy = candidate_score.get("accuracy", 0.0)
+                prompt_length = candidate_score.get("prompt_length", 0)
+                payload_kind = cand.get("payload_kind", "unknown")
+
+                # Try score.instance_scores first, then cand.instance_scores (explicit check)
+                instance_scores = (
+                    candidate_score.get('instance_scores')
+                    if 'instance_scores' in candidate_score
+                    else cand.get('instance_scores')
+                )
+                n_eval = len(instance_scores) if instance_scores and isinstance(instance_scores, list) else 0
+
+                lines.append(f"[{idx+1}] Accuracy: {accuracy:.4f} | Length: {prompt_length} | Type: {payload_kind} | N: {n_eval}")
+                lines.append("-" * 80)
+
+                obj = cand.get("object")
+                if obj and isinstance(obj, dict) and payload_kind == "transformation":
+                    # For transformations, text_replacements are nested in data
+                    data_obj = obj.get("data", {})
+                    replacement_lines = _format_text_replacements(data_obj)
+                    lines.extend(replacement_lines)
+                lines.append("")
+
+        # Add all proposal candidates
+        if attempted_candidates and isinstance(attempted_candidates, list):
+            lines.append("=" * 80)
+            lines.append(f"💡 ALL PROPOSAL CANDIDATES ({len(attempted_candidates)})")
+            lines.append("=" * 80)
+            lines.append("")
+
+            for idx, cand in enumerate(attempted_candidates):
+                if not isinstance(cand, dict):
+                    continue
+                accuracy = cand.get('accuracy', 0.0)
+                prompt_length = cand.get('prompt_length', 0)
+                tool_rate = cand.get('tool_call_rate', 0.0)
+                instance_scores = cand.get('instance_scores', [])
+                n_eval = len(instance_scores) if instance_scores else 0
+
+                lines.append(f"[{idx+1}] Accuracy: {accuracy:.4f} | Length: {prompt_length} | Tool Rate: {tool_rate:.2f} | N: {n_eval}")
+                lines.append("-" * 80)
+
+                obj = cand.get("object")
+                if obj and isinstance(obj, dict):
+                    # For proposals, text_replacements are at top level of object
+                    replacement_lines = _format_text_replacements(obj)
+                    lines.extend(replacement_lines)
+                lines.append("")
+
+        lines.append("=" * 80)
+        lines.append("END OF REPORT")
+        lines.append("=" * 80)
+
+        # Determine save location
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+        # Try to save in config directory first
+        output_dir = config_path.parent / "results"
+        output_dir.mkdir(exist_ok=True)
+        output_file = output_dir / f"gepa_results_{job_id}_{timestamp}.txt"
+
+        with open(output_file, "w", encoding="utf-8") as f:
+            f.write("\n".join(lines))
+
+        click.echo(f"\n📄 Results saved locally to: {output_file}")
+
+    except (PermissionError, OSError) as e:
+        click.echo(f"⚠️ Could not save results file locally: {e}")
+    except Exception as e:
+        click.echo(f"⚠️ Unexpected error saving results file: {e}")
+
+
+def handle_prompt_learning(
+    *,
+    cfg_path: Path,
+    backend_base: str,
+    synth_key: str,
+    task_url_override: str | None,
+    allow_experimental: bool | None,
+    dry_run: bool,
+    poll: bool,
+    poll_timeout: float,
+    poll_interval: float,
+    stream_format: str,
+) -> None:
+    """Handle prompt learning job creation (MIPRO or GEPA)."""
+    import os
+
+    overrides: dict[str, Any] = {
+        "backend": backend_base,
+    }
+
+    build = build_prompt_learning_payload(
+        config_path=cfg_path,
+        task_url=None,  # Force using TOML only
+        overrides=overrides,
+        allow_experimental=allow_experimental,
+    )
+
+    env_key = os.environ.get("ENVIRONMENT_API_KEY")
+    if not env_key:
+        raise click.ClickException("ENVIRONMENT_API_KEY required for prompt learning flow")
+
+    click.echo("Performing task app health check…")
+    health = check_task_app_health(build.task_url, env_key)
+    if not health.ok:
+        click.echo(f"Task app health check failed: {health.detail}")
+        raise click.ClickException("Aborting due to failing health check")
+    else:
+        click.echo("Task app healthy")
+
+    create_url = f"{backend_base}/prompt-learning/online/jobs"
+    headers = {"Authorization": f"Bearer {synth_key}", "Content-Type": "application/json"}
+
+    click.echo(f"POST {create_url}")
+    click.echo("Payload preview:\n" + preview_json(build.payload, limit=800))
+
+    resp = http_post(create_url, headers=headers, json_body=build.payload)
+    try:
+        js = resp.json()
+    except json.JSONDecodeError as e:
+        click.echo(f"⚠️ Failed to parse JSON response: {e}")
+        js = {"status": resp.status_code, "text": resp.text[:400]}
+    click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
+    if resp.status_code not in (200, 201):
+        raise click.ClickException("Job creation failed")
+    job_id = js.get("job_id") or js.get("id")
+    if not job_id:
+        raise click.ClickException("Response missing job id")
+
+    if not poll:
+        click.echo(f"Created job {job_id} (polling disabled)")
+        return
+
+    click.echo("\n=== Streaming Job Progress ===")
+
+    # Custom config for prompt learning to enable metrics
+    if stream_format == "chart":
+        config = StreamConfig(
+            enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
+            event_types={
+                "prompt.learning.progress",
+                "prompt.learning.gepa.start",
+                "prompt.learning.gepa.complete",
+            },
+            metric_names={"gepa.transformation.mean_score"},
+        )
+        handlers = [LossCurveHandler()]
+        click.echo("Using live loss chart (metric=gepa.transformation.mean_score)")
+    else:
+        # Enable metrics for CLI mode too
+        config = StreamConfig(
+            enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
+            metric_names={"gepa.transformation.mean_score"},
+        )
+        handlers = [CLIHandler(
+            hidden_event_types=_DEFAULT_PROMPT_LEARNING_HIDDEN_EVENTS,
+            hidden_event_substrings=_DEFAULT_RL_HIDDEN_SUBSTRINGS,
+        )]
+
+    streamer = JobStreamer(
+        base_url=backend_base,
+        api_key=synth_key,
+        job_id=job_id,
+        endpoints=StreamEndpoints.prompt_learning(job_id),
+        config=config,
+        handlers=handlers,
+        interval_seconds=poll_interval,
+        timeout_seconds=poll_timeout,
+    )
+    final_status = asyncio.run(streamer.stream_until_terminal())
+    click.echo(f"Final status: {final_status.get('status', 'unknown')}")
+    click.echo(preview_json(final_status, limit=600))
+
+    # Save results file locally
+    _save_prompt_learning_results_locally(
+        backend_base=backend_base,
+        api_key=synth_key,
+        job_id=job_id,
+        config_path=cfg_path,
+    )
 
 
 def register(cli: click.Group) -> None:
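The new prompt-learning flow talks to two backend routes that appear verbatim in the code above. A minimal sketch of the same calls with httpx, assuming a valid SYNTH_API_KEY and an illustrative backend URL; the job payload (assembled by build_prompt_learning_payload from the TOML) is elided:

    import os

    import httpx

    backend_base = "https://backend.example.com/api"  # illustrative placeholder
    headers = {
        "Authorization": f"Bearer {os.environ['SYNTH_API_KEY']}",
        "Content-Type": "application/json",
    }

    # handle_prompt_learning: create the job
    job = httpx.post(f"{backend_base}/prompt-learning/online/jobs", headers=headers, json={}).json()
    job_id = job.get("job_id") or job.get("id")

    # _save_prompt_learning_results_locally: fetch up to 10000 events afterwards
    events_url = f"{backend_base}/prompt-learning/online/jobs/{job_id}/events?limit=10000"
    events = httpx.get(events_url, headers=headers, timeout=30.0).json().get("events", [])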
synth_ai/api/train/config_finder.py
CHANGED
@@ -18,7 +18,7 @@ _STATE_FILE = _STATE_DIR / "train_cli.json"
 @dataclass(slots=True)
 class ConfigCandidate:
     path: Path
-    train_type: str  # "rl", "sft", or "unknown"
+    train_type: str  # "rl", "sft", "prompt_learning", or "unknown"
 
 
 def _load_last_config() -> Path | None:
@@ -94,6 +94,17 @@ def _iter_candidate_paths() -> Iterable[Path]:
 
 
 def _infer_config_type(data: dict) -> str:
+    # 0) Check for prompt_learning section (highest priority)
+    pl_section = data.get("prompt_learning")
+    if isinstance(pl_section, dict):
+        algorithm = pl_section.get("algorithm", "").lower()
+        if algorithm in {"mipro", "gepa"}:
+            return "prompt_learning"
+    # Also check if top-level has prompt_learning indicators
+    algorithm = data.get("algorithm")
+    if isinstance(algorithm, str) and algorithm.lower() in {"mipro", "gepa"}:
+        return "prompt_learning"
+
     # 1) Strong signals from [algorithm]
     algo = data.get("algorithm")
     if isinstance(algo, dict):
@@ -152,7 +163,7 @@ def discover_configs(explicit: list[str], *, requested_type: str | None) -> list
         cfg_type = _infer_config_type(data)
         if cfg_type == "unknown":
             raise click.ClickException(
-                f"Config {path} is missing algorithm.type/method metadata. Add type = 'rl' or '
+                f"Config {path} is missing algorithm.type/method metadata. Add type = 'rl', 'sft', or 'prompt_learning'."
             )
         candidates.append(ConfigCandidate(path=path, train_type=cfg_type))
         seen.add(path)
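The practical effect of the `_infer_config_type` change: any TOML with a `[prompt_learning]` table whose `algorithm` is `mipro` or `gepa` now routes to the new flow. A standalone sketch of that rule using stdlib tomllib (Python 3.11+); the minimal TOML shape is inferred from the diff, so treat any further keys as illustrative:

    import tomllib

    data = tomllib.loads("""
    [prompt_learning]
    algorithm = "gepa"
    """)

    pl_section = data.get("prompt_learning")
    is_prompt_learning = (
        isinstance(pl_section, dict)
        and pl_section.get("algorithm", "").lower() in {"mipro", "gepa"}
    )
    print(is_prompt_learning)  # True -> the train CLI treats this config as prompt_learning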
synth_ai/api/train/configs/__init__.py
CHANGED
@@ -1,5 +1,13 @@
-"""Typed training config loaders for RL and
+"""Typed training config loaders for RL, SFT, and Prompt Learning jobs."""
 
+from .prompt_learning import (
+    GEPAConfig,
+    MessagePatternConfig,
+    MIPROConfig,
+    PromptLearningConfig,
+    PromptLearningPolicyConfig,
+    PromptPatternConfig,
+)
 from .rl import (
     EvaluationConfig,
     JudgeConfig,
@@ -28,14 +36,20 @@ __all__ = [
     "AlgorithmConfig",
     "ComputeConfig",
     "EvaluationConfig",
+    "GEPAConfig",
     "HyperparametersConfig",
     "HyperparametersParallelism",
     "JobConfig",
     "JudgeConfig",
     "JudgeOptionsConfig",
     "LoraConfig",
+    "MIPROConfig",
+    "MessagePatternConfig",
     "ModelConfig",
     "PolicyConfig",
+    "PromptLearningConfig",
+    "PromptLearningPolicyConfig",
+    "PromptPatternConfig",
     "RewardsConfig",
     "RLConfig",
     "RLServicesConfig",