synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
synth_ai/sdk/api/train/cli.py
CHANGED
|
@@ -38,7 +38,7 @@ from synth_ai.sdk.streaming import (
|
|
|
38
38
|
)
|
|
39
39
|
|
|
40
40
|
from .builders import build_prompt_learning_payload, build_rl_payload, build_sft_payload
|
|
41
|
-
from .
|
|
41
|
+
from .local_api import check_local_api_health
|
|
42
42
|
from .graphgen import GraphGenJob
|
|
43
43
|
from .graphgen_models import load_graphgen_taskset
|
|
44
44
|
from .context_learning import ContextLearningJob
|
|
@@ -465,23 +465,23 @@ _logger.debug("[TRAIN_MODULE] Module synth_ai.sdk.api.train.cli imported")
|
|
|
465
465
|
@click.option(
|
|
466
466
|
"--type",
|
|
467
467
|
"train_type_override",
|
|
468
|
-
type=click.Choice(["prompt", "rl", "sft", "
|
|
468
|
+
type=click.Choice(["prompt", "rl", "sft", "graphgen", "context_learning"]),
|
|
469
469
|
default=None,
|
|
470
|
-
help="Explicitly set training type. Required for
|
|
470
|
+
help="Explicitly set training type. Required for GraphGen (uses JSON datasets).",
|
|
471
471
|
)
|
|
472
472
|
@click.option(
|
|
473
473
|
"--rollout-budget",
|
|
474
474
|
"rollout_budget",
|
|
475
475
|
type=int,
|
|
476
476
|
default=None,
|
|
477
|
-
help="Rollout budget for
|
|
477
|
+
help="Rollout budget for GraphGen optimization (default: 100)",
|
|
478
478
|
)
|
|
479
479
|
@click.option(
|
|
480
480
|
"--proposer-effort",
|
|
481
481
|
"proposer_effort",
|
|
482
482
|
type=click.Choice(["low", "medium", "high"]),
|
|
483
483
|
default=None,
|
|
484
|
-
help="Proposer effort level for
|
|
484
|
+
help="Proposer effort level for GraphGen (default: medium)",
|
|
485
485
|
)
|
|
486
486
|
def train_command(
|
|
487
487
|
cfg_path: Path | None,
|
|
@@ -507,7 +507,7 @@ def train_command(
|
|
|
507
507
|
proposer_effort: str | None,
|
|
508
508
|
) -> None:
|
|
509
509
|
|
|
510
|
-
"""Interactive launcher for RL / SFT / Prompt Learning /
|
|
510
|
+
"""Interactive launcher for RL / SFT / Prompt Learning / GraphGen / Context Learning jobs."""
|
|
511
511
|
import traceback
|
|
512
512
|
|
|
513
513
|
ctx: dict[str, Any] = {
|
|
@@ -544,18 +544,18 @@ def train_command(
|
|
|
544
544
|
load_dotenv(Path(env_file), override=True)
|
|
545
545
|
click.echo(f"[TRAIN_CMD] Loaded explicit .env: {env_file}", err=True)
|
|
546
546
|
|
|
547
|
-
# Handle
|
|
548
|
-
if train_type_override == "
|
|
549
|
-
# For
|
|
547
|
+
# Handle GraphGen specially - it uses JSON datasets, not TOML configs
|
|
548
|
+
if train_type_override == "graphgen":
|
|
549
|
+
# For GraphGen, dataset_path is required and cfg_path is ignored
|
|
550
550
|
if not dataset_path:
|
|
551
551
|
raise click.ClickException(
|
|
552
|
-
"
|
|
553
|
-
"Usage: synth-ai train --type
|
|
552
|
+
"GraphGen requires --dataset flag with path to JSON dataset file.\n"
|
|
553
|
+
"Usage: synth-ai train --type graphgen --dataset my_tasks.json"
|
|
554
554
|
)
|
|
555
|
-
train_type =
|
|
556
|
-
click.echo(f"[TRAIN_CMD]
|
|
555
|
+
train_type = train_type_override
|
|
556
|
+
click.echo(f"[TRAIN_CMD] GraphGen mode: using dataset {dataset_path}", err=True)
|
|
557
557
|
else:
|
|
558
|
-
# Non-
|
|
558
|
+
# Non-GraphGen: use TOML config
|
|
559
559
|
if not cfg_path:
|
|
560
560
|
available_cfgs = find_train_cfgs_in_cwd()
|
|
561
561
|
if len(available_cfgs) == 1:
|
|
@@ -614,8 +614,8 @@ def train_command(
|
|
|
614
614
|
if backend_base_url_env:
|
|
615
615
|
click.echo(f" (from BACKEND_BASE_URL={backend_base_url_env})")
|
|
616
616
|
|
|
617
|
-
# Skip TOML-based validation for
|
|
618
|
-
if train_type != "
|
|
617
|
+
# Skip TOML-based validation for GraphGen (uses JSON datasets)
|
|
618
|
+
if train_type != "graphgen" and cfg_path:
|
|
619
619
|
_validate_openai_key_if_provider_is_openai(cfg_path)
|
|
620
620
|
|
|
621
621
|
match train_type:
|
|
@@ -681,12 +681,12 @@ def train_command(
|
|
|
681
681
|
stream_format=stream_format,
|
|
682
682
|
examples_limit=examples_limit,
|
|
683
683
|
)
|
|
684
|
-
case "
|
|
684
|
+
case "graphgen":
|
|
685
685
|
if not dataset_path:
|
|
686
|
-
raise click.ClickException("
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
dataset_path=
|
|
686
|
+
raise click.ClickException("GraphGen requires a dataset path.")
|
|
687
|
+
graphgen_dataset_path = Path(dataset_path).expanduser().resolve()
|
|
688
|
+
handle_graphgen(
|
|
689
|
+
dataset_path=graphgen_dataset_path,
|
|
690
690
|
backend_base=backend_base,
|
|
691
691
|
synth_key=synth_api_key,
|
|
692
692
|
policy_model=model,
|
|
@@ -930,7 +930,7 @@ def handle_rl(
|
|
|
930
930
|
os.environ["ENVIRONMENT_API_KEY"] = env_key
|
|
931
931
|
|
|
932
932
|
click.echo("Performing task app health check…")
|
|
933
|
-
health =
|
|
933
|
+
health = check_local_api_health(build.task_url, env_key)
|
|
934
934
|
if not health.ok:
|
|
935
935
|
click.echo(f"Task app health check failed: {health.detail}")
|
|
936
936
|
raise click.ClickException("Aborting due to failing health check")
|
|
@@ -1169,7 +1169,7 @@ def handle_sft(
|
|
|
1169
1169
|
limited_path.parent.rmdir()
|
|
1170
1170
|
|
|
1171
1171
|
|
|
1172
|
-
def
|
|
1172
|
+
def handle_graphgen(
|
|
1173
1173
|
*,
|
|
1174
1174
|
dataset_path: Path,
|
|
1175
1175
|
backend_base: str,
|
|
@@ -1182,43 +1182,51 @@ def handle_adas(
|
|
|
1182
1182
|
poll_interval: float,
|
|
1183
1183
|
stream_format: str,
|
|
1184
1184
|
) -> None:
|
|
1185
|
-
"""Handle
|
|
1185
|
+
"""Handle GraphGen workflow optimization job creation and streaming.
|
|
1186
1186
|
|
|
1187
|
-
|
|
1187
|
+
GraphGen uses JSON dataset files and auto-generates task apps.
|
|
1188
1188
|
"""
|
|
1189
1189
|
ctx: dict[str, Any] = {
|
|
1190
1190
|
"dataset_path": str(dataset_path),
|
|
1191
1191
|
"backend_base": backend_base,
|
|
1192
1192
|
"poll": poll,
|
|
1193
1193
|
}
|
|
1194
|
-
log_info("
|
|
1194
|
+
log_info("handle_graphgen invoked", ctx=ctx)
|
|
1195
1195
|
|
|
1196
1196
|
# Load dataset
|
|
1197
|
-
click.echo(f"Loading
|
|
1197
|
+
click.echo(f"Loading GraphGen dataset from: {dataset_path}")
|
|
1198
1198
|
try:
|
|
1199
1199
|
dataset = load_graphgen_taskset(dataset_path)
|
|
1200
1200
|
except FileNotFoundError:
|
|
1201
1201
|
raise click.ClickException(f"Dataset file not found: {dataset_path}")
|
|
1202
1202
|
except ValueError as e:
|
|
1203
|
-
raise click.ClickException(f"Invalid
|
|
1203
|
+
raise click.ClickException(f"Invalid GraphGen dataset format: {e}")
|
|
1204
|
+
|
|
1205
|
+
problem_spec = None
|
|
1206
|
+
try:
|
|
1207
|
+
raw_dataset = json.loads(dataset_path.read_text())
|
|
1208
|
+
problem_spec = raw_dataset.get("problem_spec") or raw_dataset.get("initial_prompt")
|
|
1209
|
+
except Exception:
|
|
1210
|
+
problem_spec = None
|
|
1204
1211
|
|
|
1205
1212
|
click.echo(f"Dataset loaded: {dataset.metadata.name}")
|
|
1206
1213
|
click.echo(f" Tasks: {len(dataset.tasks)}")
|
|
1207
1214
|
click.echo(f" Gold outputs: {len(dataset.gold_outputs)}")
|
|
1208
|
-
click.echo(f"
|
|
1215
|
+
click.echo(f" Verifier mode: {dataset.verifier_config.mode}")
|
|
1209
1216
|
|
|
1210
|
-
# Create
|
|
1217
|
+
# Create GraphGen job
|
|
1211
1218
|
job = GraphGenJob.from_dataset(
|
|
1212
1219
|
dataset=dataset,
|
|
1213
1220
|
policy_model=policy_model or "gpt-4o-mini",
|
|
1214
1221
|
rollout_budget=rollout_budget or 100,
|
|
1215
1222
|
proposer_effort=proposer_effort or "medium", # type: ignore
|
|
1223
|
+
problem_spec=problem_spec,
|
|
1216
1224
|
backend_url=backend_base,
|
|
1217
1225
|
api_key=synth_key,
|
|
1218
1226
|
auto_start=True,
|
|
1219
1227
|
)
|
|
1220
1228
|
|
|
1221
|
-
click.echo("\n=== Submitting
|
|
1229
|
+
click.echo("\n=== Submitting GraphGen Job ===")
|
|
1222
1230
|
click.echo(f"Policy model: {job.config.policy_model}")
|
|
1223
1231
|
click.echo(f"Rollout budget: {job.config.rollout_budget}")
|
|
1224
1232
|
click.echo(f"Proposer effort: {job.config.proposer_effort}")
|
|
@@ -1229,7 +1237,7 @@ def handle_adas(
|
|
|
1229
1237
|
raise click.ClickException(str(e))
|
|
1230
1238
|
|
|
1231
1239
|
click.echo(f"\n✓ Job created:")
|
|
1232
|
-
click.echo(f"
|
|
1240
|
+
click.echo(f" GraphGen Job ID: {result.graphgen_job_id}")
|
|
1233
1241
|
click.echo(f" Status: {result.status}")
|
|
1234
1242
|
|
|
1235
1243
|
if not poll:
|
|
@@ -1979,7 +1987,7 @@ def handle_prompt_learning(
|
|
|
1979
1987
|
click.echo("Performing task app health check…")
|
|
1980
1988
|
click.echo(f"Task app URL: {build.task_url}")
|
|
1981
1989
|
click.echo("⏳ Checking /health endpoint (timeout: 10s)...")
|
|
1982
|
-
health =
|
|
1990
|
+
health = check_local_api_health(build.task_url, env_key, timeout=10.0)
|
|
1983
1991
|
if not health.ok:
|
|
1984
1992
|
click.echo(f"❌ Task app health check failed: {health.detail}")
|
|
1985
1993
|
click.echo(f" Health status: {health.health_status}")
|
|
@@ -6,12 +6,11 @@ from .prompt_learning import (
|
|
|
6
6
|
MIPROConfig,
|
|
7
7
|
PromptLearningConfig,
|
|
8
8
|
PromptLearningPolicyConfig,
|
|
9
|
+
PromptLearningVerifierConfig,
|
|
9
10
|
PromptPatternConfig,
|
|
10
11
|
)
|
|
11
12
|
from .rl import (
|
|
12
13
|
EvaluationConfig,
|
|
13
|
-
JudgeConfig,
|
|
14
|
-
JudgeOptionsConfig,
|
|
15
14
|
ModelConfig,
|
|
16
15
|
RewardsConfig,
|
|
17
16
|
RLConfig,
|
|
@@ -19,6 +18,8 @@ from .rl import (
|
|
|
19
18
|
RLTrainingConfig,
|
|
20
19
|
RolloutConfig,
|
|
21
20
|
RubricConfig,
|
|
21
|
+
VerifierConfig,
|
|
22
|
+
VerifierOptionsConfig,
|
|
22
23
|
WeightSyncConfig,
|
|
23
24
|
)
|
|
24
25
|
from .sft import (
|
|
@@ -40,8 +41,9 @@ __all__ = [
|
|
|
40
41
|
"HyperparametersConfig",
|
|
41
42
|
"HyperparametersParallelism",
|
|
42
43
|
"JobConfig",
|
|
43
|
-
"
|
|
44
|
-
"
|
|
44
|
+
"PromptLearningVerifierConfig",
|
|
45
|
+
"VerifierConfig",
|
|
46
|
+
"VerifierOptionsConfig",
|
|
45
47
|
"LoraConfig",
|
|
46
48
|
"MIPROConfig",
|
|
47
49
|
"MessagePatternConfig",
|
|
@@ -1,4 +1,40 @@
|
|
|
1
|
-
"""Prompt Learning configuration models for MIPRO and GEPA.
|
|
1
|
+
"""Prompt Learning configuration models for MIPRO and GEPA.
|
|
2
|
+
|
|
3
|
+
This module defines the configuration schema for prompt optimization jobs using:
|
|
4
|
+
- **GEPA**: Genetic Evolution of Prompt Architectures - evolutionary optimization
|
|
5
|
+
- **MIPRO**: Meta-learning with bootstrap phase and TPE optimization
|
|
6
|
+
|
|
7
|
+
Example TOML configuration (GEPA):
|
|
8
|
+
```toml
|
|
9
|
+
[prompt_learning]
|
|
10
|
+
algorithm = "gepa"
|
|
11
|
+
task_app_url = "https://your-tunnel.trycloudflare.com"
|
|
12
|
+
task_app_api_key = "$ENVIRONMENT_API_KEY"
|
|
13
|
+
|
|
14
|
+
[prompt_learning.policy]
|
|
15
|
+
model = "gpt-4o-mini"
|
|
16
|
+
provider = "openai"
|
|
17
|
+
|
|
18
|
+
[prompt_learning.gepa]
|
|
19
|
+
env_name = "banking77"
|
|
20
|
+
proposer_effort = "LOW"
|
|
21
|
+
|
|
22
|
+
[prompt_learning.gepa.rollout]
|
|
23
|
+
budget = 100
|
|
24
|
+
max_concurrent = 20
|
|
25
|
+
|
|
26
|
+
[prompt_learning.gepa.evaluation]
|
|
27
|
+
seeds = {start = 0, end = 50}
|
|
28
|
+
|
|
29
|
+
[prompt_learning.gepa.population]
|
|
30
|
+
num_generations = 10
|
|
31
|
+
children_per_generation = 5
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
See Also:
|
|
35
|
+
- Training reference: /training/gepa, /training/mipro
|
|
36
|
+
- Quickstart: /quickstart/prompt-optimization-gepa
|
|
37
|
+
"""
|
|
2
38
|
from __future__ import annotations
|
|
3
39
|
|
|
4
40
|
from collections.abc import Mapping, Sequence
|
|
@@ -156,7 +192,7 @@ class MIPROSeedConfig(ExtraModel):
|
|
|
156
192
|
return _parse_seeds(v) or []
|
|
157
193
|
|
|
158
194
|
|
|
159
|
-
class
|
|
195
|
+
class PromptLearningVerifierConfig(ExtraModel):
|
|
160
196
|
"""Verifier configuration shared by GEPA and MIPRO.
|
|
161
197
|
|
|
162
198
|
This configures LLM-based evaluation of agent trajectories during prompt optimization.
|
|
@@ -166,15 +202,13 @@ class PromptLearningJudgeConfig(ExtraModel):
|
|
|
166
202
|
enabled: Whether to enable verifier-based scoring.
|
|
167
203
|
reward_source: Source of the final reward for optimization.
|
|
168
204
|
- "task_app": Use only environment rewards from task app (default).
|
|
169
|
-
- "
|
|
205
|
+
- "verifier": Use only verifier quality scores.
|
|
170
206
|
- "fused": Weighted combination of environment and verifier rewards.
|
|
171
207
|
backend_base: Base URL for the verifier service (e.g. "https://api.usesynth.ai").
|
|
172
208
|
backend_api_key_env: Env var containing the Synth API key (default: "SYNTH_API_KEY").
|
|
173
209
|
backend_provider: Provider for the verifier model (e.g. "openai", "groq").
|
|
174
210
|
backend_model: Model used to execute the verifier rubric or graph (e.g. "gpt-4o-mini").
|
|
175
|
-
|
|
176
|
-
Use this to point to a specific, versioned verifier artifact.
|
|
177
|
-
backend_rubric_id: Legacy alias for synth_verifier_id.
|
|
211
|
+
verifier_graph_id: ID or name of a registered Verifier Graph on the backend.
|
|
178
212
|
backend_event_enabled: Whether to enable fine-grained event-level scoring.
|
|
179
213
|
backend_outcome_enabled: Whether to enable episode-level outcome scoring.
|
|
180
214
|
weight_env: Weight for environment rewards in "fused" mode (default: 1.0).
|
|
@@ -182,13 +216,12 @@ class PromptLearningJudgeConfig(ExtraModel):
|
|
|
182
216
|
weight_outcome: Weight for verifier outcome rewards in "fused" mode (default: 0.0).
|
|
183
217
|
"""
|
|
184
218
|
enabled: bool = False
|
|
185
|
-
reward_source: Literal["task_app", "
|
|
219
|
+
reward_source: Literal["task_app", "verifier", "fused"] = "task_app"
|
|
186
220
|
backend_base: str = ""
|
|
187
221
|
backend_api_key_env: str = "SYNTH_API_KEY"
|
|
188
222
|
backend_provider: str = ""
|
|
189
223
|
backend_model: str = ""
|
|
190
|
-
|
|
191
|
-
backend_rubric_id: str = "" # Legacy alias for synth_verifier_id
|
|
224
|
+
verifier_graph_id: str = ""
|
|
192
225
|
backend_event_enabled: bool = True
|
|
193
226
|
backend_outcome_enabled: bool = True
|
|
194
227
|
backend_options: Dict[str, Any] = Field(default_factory=dict)
|
|
@@ -201,21 +234,6 @@ class PromptLearningJudgeConfig(ExtraModel):
|
|
|
201
234
|
spec_max_tokens: int = 5000
|
|
202
235
|
spec_context: Optional[str] = None
|
|
203
236
|
|
|
204
|
-
@model_validator(mode="before")
|
|
205
|
-
@classmethod
|
|
206
|
-
def _sync_verifier_ids(cls, data: Any) -> Any:
|
|
207
|
-
"""Sync synth_verifier_id and backend_rubric_id."""
|
|
208
|
-
if isinstance(data, dict):
|
|
209
|
-
if not data.get("synth_verifier_id") and data.get("backend_rubric_id"):
|
|
210
|
-
data["synth_verifier_id"] = data["backend_rubric_id"]
|
|
211
|
-
elif not data.get("backend_rubric_id") and data.get("synth_verifier_id"):
|
|
212
|
-
data["backend_rubric_id"] = data["synth_verifier_id"]
|
|
213
|
-
return data
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
class PromptLearningVerifierConfig(PromptLearningJudgeConfig):
|
|
217
|
-
"""Alias for PromptLearningJudgeConfig with verifier terminology."""
|
|
218
|
-
|
|
219
237
|
|
|
220
238
|
class ProxyModelsConfig(ExtraModel):
|
|
221
239
|
"""Configuration for proxy usage on policy evaluations.
|
|
@@ -697,8 +715,8 @@ class MIPROConfig(ExtraModel):
|
|
|
697
715
|
# Meta-update configuration
|
|
698
716
|
meta_update: dict[str, Any] | None = None
|
|
699
717
|
|
|
700
|
-
#
|
|
701
|
-
|
|
718
|
+
# Verifier configuration (shared with GEPA)
|
|
719
|
+
verifier: PromptLearningVerifierConfig | dict[str, Any] | None = None
|
|
702
720
|
|
|
703
721
|
# Proxy models configuration (optional, can also be at top-level)
|
|
704
722
|
proxy_models: ProxyModelsConfig | dict[str, Any] | None = None
|
|
@@ -1165,7 +1183,7 @@ class GEPAConfig(ExtraModel):
|
|
|
1165
1183
|
population: GEPAPopulationConfig | None = None
|
|
1166
1184
|
archive: GEPAArchiveConfig | None = None
|
|
1167
1185
|
token: GEPATokenConfig | None = None
|
|
1168
|
-
|
|
1186
|
+
verifier: PromptLearningVerifierConfig | dict[str, Any] | None = None
|
|
1169
1187
|
proxy_models: ProxyModelsConfig | dict[str, Any] | None = None # Proxy models config (can be at top-level or gepa-specific)
|
|
1170
1188
|
adaptive_pool: AdaptivePoolConfig | dict[str, Any] | None = None # Adaptive pooling config
|
|
1171
1189
|
adaptive_batch: GEPAAdaptiveBatchConfig | dict[str, Any] | None = None # Adaptive batch config (GEPA only)
|
|
@@ -1407,7 +1425,7 @@ class GEPAConfig(ExtraModel):
|
|
|
1407
1425
|
flat_data = {}
|
|
1408
1426
|
|
|
1409
1427
|
for key, value in data.items():
|
|
1410
|
-
if key in ("rollout", "evaluation", "mutation", "population", "archive", "token", "modules", "proxy_models", "adaptive_pool", "adaptive_batch", "
|
|
1428
|
+
if key in ("rollout", "evaluation", "mutation", "population", "archive", "token", "modules", "proxy_models", "adaptive_pool", "adaptive_batch", "verifier"):
|
|
1411
1429
|
nested_data[key] = value
|
|
1412
1430
|
else:
|
|
1413
1431
|
flat_data[key] = value
|
|
@@ -1483,7 +1501,83 @@ class GEPAConfig(ExtraModel):
|
|
|
1483
1501
|
|
|
1484
1502
|
|
|
1485
1503
|
class PromptLearningConfig(ExtraModel):
|
|
1486
|
-
"""
|
|
1504
|
+
"""Root configuration for Prompt Learning jobs (GEPA and MIPRO).
|
|
1505
|
+
|
|
1506
|
+
This is the top-level config loaded from a TOML file. Use `PromptLearningConfig.from_path()`
|
|
1507
|
+
to load from a file, or `PromptLearningConfig.from_mapping()` to load from a dict.
|
|
1508
|
+
|
|
1509
|
+
Prompt learning optimizes prompts for a given task app and dataset using one of
|
|
1510
|
+
two algorithms:
|
|
1511
|
+
- **GEPA**: Genetic Evolution of Prompt Architectures - evolutionary optimization
|
|
1512
|
+
with crossover, mutation, and selection across generations
|
|
1513
|
+
- **MIPRO**: Meta-learning with bootstrap phase and Tree-structured Parzen Estimator
|
|
1514
|
+
(TPE) optimization for hyperparameter tuning
|
|
1515
|
+
|
|
1516
|
+
Example:
|
|
1517
|
+
```python
|
|
1518
|
+
from synth_ai.sdk.api.train.configs.prompt_learning import PromptLearningConfig
|
|
1519
|
+
|
|
1520
|
+
# Load from file
|
|
1521
|
+
config = PromptLearningConfig.from_path("prompt_learning.toml")
|
|
1522
|
+
|
|
1523
|
+
# Or from dict
|
|
1524
|
+
config = PromptLearningConfig.from_mapping({
|
|
1525
|
+
"algorithm": "gepa",
|
|
1526
|
+
"task_app_url": "https://your-tunnel.trycloudflare.com",
|
|
1527
|
+
"gepa": {
|
|
1528
|
+
"env_name": "banking77",
|
|
1529
|
+
"policy": {"model": "gpt-4o-mini", "provider": "openai"},
|
|
1530
|
+
"generations": 5,
|
|
1531
|
+
"population_size": 4,
|
|
1532
|
+
},
|
|
1533
|
+
})
|
|
1534
|
+
```
|
|
1535
|
+
|
|
1536
|
+
Attributes:
|
|
1537
|
+
algorithm: Optimization algorithm - "gepa" or "mipro".
|
|
1538
|
+
task_app_url: URL of your task app (typically a Cloudflare tunnel URL).
|
|
1539
|
+
task_app_api_key: API key for authenticating with the task app.
|
|
1540
|
+
Defaults to ENVIRONMENT_API_KEY env var.
|
|
1541
|
+
task_app_id: Optional identifier for the task app (for logging).
|
|
1542
|
+
initial_prompt: Initial prompt pattern to seed optimization.
|
|
1543
|
+
policy: Policy (LLM) configuration for rollouts.
|
|
1544
|
+
mipro: MIPRO-specific configuration (if algorithm="mipro").
|
|
1545
|
+
gepa: GEPA-specific configuration (if algorithm="gepa").
|
|
1546
|
+
verifier: Optional verifier configuration for LLM-based reward scoring.
|
|
1547
|
+
proxy_models: Proxy models configuration for cost-effective evaluation.
|
|
1548
|
+
env_config: Additional environment configuration passed to task app.
|
|
1549
|
+
free_tier: Enable free tier mode with cost-effective OSS models.
|
|
1550
|
+
|
|
1551
|
+
Returns:
|
|
1552
|
+
After training completes, you receive a result dict:
|
|
1553
|
+
```python
|
|
1554
|
+
{
|
|
1555
|
+
"status": "succeeded",
|
|
1556
|
+
"best_score": 0.92,
|
|
1557
|
+
"best_snapshot_id": "snap_abc123",
|
|
1558
|
+
"final_prompt": "You are a helpful assistant...",
|
|
1559
|
+
"metrics": {
|
|
1560
|
+
"generations_completed": 5,
|
|
1561
|
+
"total_rollouts": 200,
|
|
1562
|
+
"improvement": 0.15,
|
|
1563
|
+
},
|
|
1564
|
+
}
|
|
1565
|
+
```
|
|
1566
|
+
|
|
1567
|
+
Events:
|
|
1568
|
+
During training, you'll receive streaming events:
|
|
1569
|
+
- `prompt_learning.created` - Job created
|
|
1570
|
+
- `prompt_learning.running` - Training started
|
|
1571
|
+
- `prompt_learning.generation.started` - New generation began
|
|
1572
|
+
- `prompt_learning.candidate.evaluated` - Candidate prompt evaluated
|
|
1573
|
+
- `prompt_learning.generation.completed` - Generation finished with best score
|
|
1574
|
+
- `prompt_learning.frontier.updated` - Pareto frontier updated (new best found)
|
|
1575
|
+
- `prompt_learning.succeeded` / `prompt_learning.failed` - Terminal states
|
|
1576
|
+
|
|
1577
|
+
See Also:
|
|
1578
|
+
- Training reference: /training/gepa, /training/mipro
|
|
1579
|
+
- Quickstart: /quickstart/prompt-optimization-gepa
|
|
1580
|
+
"""
|
|
1487
1581
|
algorithm: str # "mipro" or "gepa"
|
|
1488
1582
|
task_app_url: str
|
|
1489
1583
|
task_app_api_key: str | None = None
|
|
@@ -1492,7 +1586,7 @@ class PromptLearningConfig(ExtraModel):
|
|
|
1492
1586
|
policy: PromptLearningPolicyConfig | None = None
|
|
1493
1587
|
mipro: MIPROConfig | None = None
|
|
1494
1588
|
gepa: GEPAConfig | None = None
|
|
1495
|
-
|
|
1589
|
+
verifier: PromptLearningVerifierConfig | dict[str, Any] | None = None
|
|
1496
1590
|
proxy_models: ProxyModelsConfig | dict[str, Any] | None = None # Proxy models config (can be at top-level or algorithm-specific)
|
|
1497
1591
|
env_config: dict[str, Any] | None = None
|
|
1498
1592
|
|
|
@@ -1665,8 +1759,8 @@ class PromptLearningConfig(ExtraModel):
|
|
|
1665
1759
|
mipro_data["proxy_models"] = ProxyModelsConfig.model_validate(mipro_data["proxy_models"])
|
|
1666
1760
|
# If proxy_models not specified, leave as None (defaults to disabled)
|
|
1667
1761
|
|
|
1668
|
-
if "
|
|
1669
|
-
pl_data["
|
|
1762
|
+
if "verifier" in pl_data and isinstance(pl_data["verifier"], dict):
|
|
1763
|
+
pl_data["verifier"] = PromptLearningVerifierConfig.model_validate(pl_data["verifier"])
|
|
1670
1764
|
|
|
1671
1765
|
return cls.model_validate(pl_data)
|
|
1672
1766
|
|
|
@@ -1696,7 +1790,7 @@ __all__ = [
|
|
|
1696
1790
|
"PromptLearningConfig",
|
|
1697
1791
|
"PromptLearningPolicyConfig",
|
|
1698
1792
|
"PromptPatternConfig",
|
|
1699
|
-
"
|
|
1793
|
+
"PromptLearningVerifierConfig",
|
|
1700
1794
|
"ProxyModelsConfig",
|
|
1701
1795
|
"AdaptivePoolConfig",
|
|
1702
1796
|
"AdaptiveCurriculumLevel",
|