synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
synth_ai/__init__.py
CHANGED
@@ -12,15 +12,15 @@ except Exception:
     # Silently fail if log filter can't be installed
     pass

-#
-from synth_ai.sdk.
+# Verifier schemas live under sdk/graphs/verifier_schemas.py
+from synth_ai.sdk.graphs.verifier_schemas import (
     CriterionScorePayload,
-    JudgeOptions,
-    JudgeScoreRequest,
-    JudgeScoreResponse,
-    JudgeTaskApp,
-    JudgeTracePayload,
     ReviewPayload,
+    VerifierOptions,
+    VerifierScoreRequest,
+    VerifierScoreResponse,
+    VerifierTaskApp,
+    VerifierTracePayload,
 )

 try: # Prefer the installed package metadata when available
@@ -45,12 +45,12 @@ EventPartitionElement = RewardSignal = SystemTrace = TrainingQuestion = None  #
 trace_event_async = trace_event_sync = upload = None # type: ignore

 __all__ = [
-    #
-    "
-    "
-    "
-    "
-    "
+    # Verifier API contracts
+    "VerifierScoreRequest",
+    "VerifierScoreResponse",
+    "VerifierOptions",
+    "VerifierTaskApp",
+    "VerifierTracePayload",
     "ReviewPayload",
     "CriterionScorePayload",
 ] # Explicitly define public API (v1 tracing omitted in minimal env)
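The hunks above rename the judge-scoring contracts to verifier-scoring contracts and re-export the new names from the package root. For downstream code that has to run against both versions, a minimal compatibility shim could look like the sketch below; it assumes only that 0.4.1 exposed the Judge* names and 0.4.4 exposes the Verifier* names at the package root, as this diff indicates, and is not part of synth-ai itself.

```python
# Compatibility sketch for the Judge -> Verifier rename shown above.
# Assumption: 0.4.1 re-exported Judge* names and 0.4.4 re-exports Verifier*
# names from synth_ai/__init__.py, per the hunks in this diff.
try:
    # synth-ai >= 0.4.4
    from synth_ai import VerifierScoreRequest, VerifierScoreResponse
except ImportError:
    # synth-ai 0.4.1: the same contracts under the old Judge* names
    from synth_ai import (
        JudgeScoreRequest as VerifierScoreRequest,
        JudgeScoreResponse as VerifierScoreResponse,
    )
```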
synth_ai/cli/__init__.py
CHANGED
@@ -1,8 +1,7 @@
 """CLI subcommands for Synth AI.

-This package hosts modular commands
-
-pyproject entry point `synth_ai.cli:cli`.
+This package hosts modular commands and exposes a top-level Click group
+named `cli` compatible with the pyproject entry point `synth_ai.cli:cli`.
 """

 import importlib
@@ -11,9 +10,6 @@ from collections.abc import Callable
 from typing import Any

 from synth_ai.cli.agents import claude_cmd, codex_cmd, opencode_cmd
-from synth_ai.cli.commands.baseline import command as baseline_cmd
-from synth_ai.cli.commands.baseline.list import list_command as baseline_list_cmd
-from synth_ai.cli.commands.eval import command as eval_cmd
 from synth_ai.cli.demos.demo import demo_cmd
 from synth_ai.cli.deploy import deploy_cmd
 from synth_ai.cli.infra.mcp import mcp_cmd
@@ -21,7 +17,6 @@ from synth_ai.cli.infra.modal_app import modal_app_cmd
 from synth_ai.cli.infra.setup import setup_cmd
 from synth_ai.cli.task_apps import task_app_cmd
 from synth_ai.cli.training.train_cfg import train_cfg_cmd
-from synth_ai.cli.usage import usage_cmd

 # Load environment variables from a local .env if present (repo root)
 try:
@@ -67,24 +62,20 @@ cli = _cli_module.cli # type: ignore[attr-defined]

 # Register core commands implemented as standalone modules

-cli.add_command(baseline_cmd, name="baseline")
-baseline_cmd.add_command(baseline_list_cmd, name="list")
 cli.add_command(claude_cmd, name="claude")
 cli.add_command(codex_cmd, name="codex")
 cli.add_command(demo_cmd, name="demo")
 cli.add_command(deploy_cmd, name="deploy")
-cli.add_command(eval_cmd, name="eval")
 cli.add_command(mcp_cmd, name="mcp")
 cli.add_command(modal_app_cmd, name="modal-app")
 cli.add_command(opencode_cmd, name="opencode")
 cli.add_command(setup_cmd, name="setup")
 cli.add_command(task_app_cmd, name="task-app")
 cli.add_command(train_cfg_cmd, name="train-cfg")
-cli.add_command(usage_cmd, name="usage")


 # Register optional subcommands packaged under synth_ai.cli.*
-for _module_path in ("synth_ai.cli.commands.demo", "synth_ai.cli.
+for _module_path in ("synth_ai.cli.commands.demo", "synth_ai.cli.infra.turso"):
     module = _maybe_import(_module_path)
     if not module:
         continue
@@ -108,6 +99,9 @@ _maybe_call("synth_ai.cli.commands.help.core", "register", cli)
 # Register scan command
 _maybe_call("synth_ai.cli.commands.scan", "register", cli)

+# Register eval command
+_maybe_call("synth_ai.cli.commands.eval", "register", cli)
+
 # Train CLI lives under synth_ai.sdk.api.train
 _maybe_call("synth_ai.sdk.api.train", "register", cli)

@@ -136,6 +130,3 @@ _maybe_call("synth_ai.cli.utils.queue", "register", cli)

 # Artifacts commands
 _maybe_call("synth_ai.cli.commands.artifacts", "register", cli)
-
-# Research Agent commands
-_maybe_call("synth_ai.sdk.api.research_agent", "register", cli)
synth_ai/cli/commands/eval/__init__.py
CHANGED
@@ -1,19 +1,10 @@
-
-
-from .errors import EvalCliError
-from .validation import validate_eval_options
+"""Eval command package."""

-
-    "command",
-    "get_command",
-    "EvalCliError",
-    "validate_eval_options",
-]
+from __future__ import annotations

+def register(cli) -> None:
+    from synth_ai.cli.commands.eval.core import eval_command
+    cli.add_command(eval_command, name="eval")

-def __getattr__(name: str):
-    if name in {"command", "get_command"}:
-        from .core import command, get_command

-
-    raise AttributeError(name)
+__all__ = ["register"]
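Together with the cli/__init__.py hunk above, this moves `eval` from an eagerly imported command to a lazily registered one: the top-level group asks the eval package to register itself via `_maybe_call`. A rough sketch of that pattern, using Click and a hypothetical re-implementation of the helper (the package's own `_maybe_call` may differ), follows.

```python
# Sketch of the lazy-registration pattern set up by the hunks above.
# `_maybe_call` here is a hypothetical stand-in for the private helper in
# synth_ai/cli/__init__.py; only the call shape is taken from the diff.
import importlib

import click


@click.group()
def cli() -> None:
    """Stand-in for the top-level group exposed as synth_ai.cli:cli."""


def _maybe_call(module_path: str, attr: str, *args) -> None:
    # Lazy registration: import the module if it is available and invoke the
    # named hook; a missing optional module simply leaves the command out.
    try:
        module = importlib.import_module(module_path)
    except ImportError:
        return
    hook = getattr(module, attr, None)
    if callable(hook):
        hook(*args)


# Mirrors the new wiring: the eval command is added only if the eval package
# imports cleanly and exposes register(cli).
_maybe_call("synth_ai.cli.commands.eval", "register", cli)
```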
synth_ai/cli/commands/eval/config.py
ADDED
@@ -0,0 +1,338 @@
+"""Eval command configuration loading and normalization.
+
+This module handles loading and resolving evaluation configuration from:
+- TOML config files (legacy eval format or prompt_learning format)
+- Command-line arguments (override config values)
+- Environment variables (for API keys, etc.)
+
+**Config File Formats:**
+
+1. **Legacy Eval Format:**
+   ```toml
+   [eval]
+   app_id = "banking77"
+   url = "http://localhost:8103"
+   env_name = "banking77"
+   seeds = [0, 1, 2, 3, 4]
+
+   [eval.policy_config]
+   model = "gpt-4"
+   provider = "openai"
+   ```
+
+2. **Prompt Learning Format:**
+   ```toml
+   [prompt_learning]
+   task_app_id = "banking77"
+   task_app_url = "http://localhost:8103"
+
+   [prompt_learning.gepa]
+   env_name = "banking77"
+
+   [prompt_learning.gepa.evaluation]
+   seeds = [0, 1, 2, 3, 4]
+   ```
+
+**See Also:**
+- `synth_ai.cli.commands.eval.core.eval_command()`: CLI entry point
+- `synth_ai.cli.commands.eval.runner.run_eval()`: Uses resolved config
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Literal
+
+from synth_ai.sdk.api.train.configs.prompt_learning import PromptLearningConfig
+from synth_ai.sdk.api.train.utils import load_toml
+from synth_ai.sdk.task.contracts import RolloutMode
+
+
+SeedSet = Literal["seeds", "validation_seeds", "test_pool"]
+
+
+@dataclass(slots=True)
+class EvalRunConfig:
+    """Configuration for evaluation runs.
+
+    This dataclass holds all configuration needed to execute an evaluation
+    against a task app. Values can come from TOML config files, CLI arguments,
+    or environment variables.
+
+    **Required Fields:**
+        app_id: Task app identifier
+        task_app_url: URL of running task app (or None to spawn locally)
+        seeds: List of seeds/indices to evaluate
+
+    **Optional Fields:**
+        env_name: Environment name (usually matches app_id)
+        policy_config: Model and provider configuration
+        backend_url: Backend URL for trace capture (enables backend mode)
+        concurrency: Number of parallel rollouts
+        return_trace: Whether to include traces in responses
+
+    **Example:**
+        ```python
+        config = EvalRunConfig(
+            app_id="banking77",
+            task_app_url="http://localhost:8103",
+            backend_url="http://localhost:8000",
+            env_name="banking77",
+            seeds=[0, 1, 2, 3, 4],
+            policy_config={"model": "gpt-4", "provider": "openai"},
+            concurrency=5,
+            return_trace=True,
+        )
+        ```
+    """
+    app_id: str
+    task_app_url: str | None
+    task_app_api_key: str | None
+    env_name: str | None
+    env_config: dict[str, Any] = field(default_factory=dict)
+    policy_name: str | None = None
+    policy_config: dict[str, Any] = field(default_factory=dict)
+    seeds: list[int] = field(default_factory=list)
+    ops: list[str] = field(default_factory=list)
+    mode: RolloutMode = RolloutMode.EVAL
+    return_trace: bool = False
+    trace_format: str = "compact"
+    concurrency: int = 1
+    metadata: dict[str, str] = field(default_factory=dict)
+    output_txt: Path | None = None
+    output_json: Path | None = None
+    verifier_config: dict[str, Any] | None = None
+    backend_url: str | None = None
+    backend_api_key: str | None = None
+    wait: bool = False
+    poll_interval: float = 5.0
+    traces_dir: Path | None = None
+    config_path: Path | None = None
+    timeout: float | None = None
+
+
+def load_eval_toml(path: Path) -> dict[str, Any]:
+    if not path.exists():
+        raise FileNotFoundError(f"Eval config not found: {path}")
+    return load_toml(path)
+
+
+def _select_seed_pool(
+    *,
+    seeds: list[int] | None,
+    validation_seeds: list[int] | None,
+    test_pool: list[int] | None,
+    seed_set: SeedSet,
+) -> list[int]:
+    if seed_set == "validation_seeds" and validation_seeds:
+        return validation_seeds
+    if seed_set == "test_pool" and test_pool:
+        return test_pool
+    if seeds:
+        return seeds
+    if validation_seeds:
+        return validation_seeds
+    if test_pool:
+        return test_pool
+    return []
+
+
+def _from_prompt_learning(
+    raw: dict[str, Any],
+    *,
+    seed_set: SeedSet,
+) -> EvalRunConfig:
+    pl_cfg = PromptLearningConfig.from_mapping(raw)
+    gepa = pl_cfg.gepa
+    mipro = pl_cfg.mipro
+
+    eval_cfg = gepa.evaluation if gepa else None
+    seeds = _select_seed_pool(
+        seeds=eval_cfg.seeds if eval_cfg else None,
+        validation_seeds=eval_cfg.validation_seeds if eval_cfg else None,
+        test_pool=eval_cfg.test_pool if eval_cfg else None,
+        seed_set=seed_set,
+    )
+
+    env_name = None
+    env_config: dict[str, Any] = {}
+    if gepa:
+        env_name = gepa.env_name
+        env_config = dict(gepa.env_config or {})
+    elif mipro:
+        env_name = mipro.env_name
+        env_config = dict(mipro.env_config or {})
+
+    policy_cfg: dict[str, Any] = {}
+    if pl_cfg.policy:
+        policy_cfg = {
+            "model": pl_cfg.policy.model,
+            "provider": pl_cfg.policy.provider,
+        }
+        if pl_cfg.policy.inference_url:
+            policy_cfg["inference_url"] = pl_cfg.policy.inference_url
+
+    app_id = pl_cfg.task_app_id or (env_name or "")
+    verifier_cfg = None
+    if pl_cfg.verifier:
+        if isinstance(pl_cfg.verifier, dict):
+            verifier_cfg = dict(pl_cfg.verifier)
+        else:
+            verifier_cfg = pl_cfg.verifier.model_dump(mode="python")
+
+    return EvalRunConfig(
+        app_id=app_id,
+        task_app_url=pl_cfg.task_app_url,
+        task_app_api_key=pl_cfg.task_app_api_key,
+        env_name=env_name,
+        env_config=env_config,
+        policy_name=pl_cfg.policy.policy_name if pl_cfg.policy else None,
+        policy_config=policy_cfg,
+        seeds=seeds,
+        ops=[],
+        concurrency=(gepa.rollout.max_concurrent if gepa and gepa.rollout else 1),
+        verifier_config=verifier_cfg,
+    )
+
+
+def _from_legacy_eval(raw: dict[str, Any]) -> EvalRunConfig:
+    eval_section = raw.get("eval", {})
+    if not isinstance(eval_section, dict):
+        eval_section = {}
+    app_id = str(eval_section.get("app_id") or "").strip()
+    model = str(eval_section.get("model") or "").strip()
+    policy_cfg = dict(eval_section.get("policy_config") or {})
+    if model and "model" not in policy_cfg:
+        policy_cfg["model"] = model
+    if "provider" not in policy_cfg and eval_section.get("provider"):
+        policy_cfg["provider"] = eval_section.get("provider")
+    return EvalRunConfig(
+        app_id=app_id,
+        task_app_url=eval_section.get("url") or eval_section.get("task_app_url"),
+        task_app_api_key=eval_section.get("task_app_api_key"),
+        env_name=eval_section.get("env_name"),
+        env_config=dict(eval_section.get("env_config") or {}),
+        policy_name=eval_section.get("policy_name"),
+        policy_config=policy_cfg,
+        seeds=list(eval_section.get("seeds") or []),
+        ops=list(eval_section.get("ops") or []),
+        return_trace=bool(eval_section.get("return_trace", False)),
+        trace_format=str(eval_section.get("trace_format") or "compact"),
+        concurrency=int(eval_section.get("concurrency") or 1),
+        metadata=dict(eval_section.get("metadata") or {}),
+    )
+
+
+def resolve_eval_config(
+    *,
+    config_path: Path | None,
+    cli_app_id: str | None,
+    cli_model: str | None,
+    cli_seeds: list[int] | None,
+    cli_url: str | None,
+    cli_env_file: str | None,
+    cli_ops: list[str] | None,
+    cli_return_trace: bool | None,
+    cli_concurrency: int | None,
+    cli_output_txt: Path | None,
+    cli_output_json: Path | None,
+    cli_backend_url: str | None,
+    cli_wait: bool,
+    cli_poll_interval: float | None,
+    cli_traces_dir: Path | None,
+    seed_set: SeedSet,
+    metadata: dict[str, str],
+) -> EvalRunConfig:
+    """Resolve evaluation configuration from multiple sources.
+
+    Loads configuration from TOML file (if provided) and merges with CLI arguments.
+    CLI arguments take precedence over config file values.
+
+    **Config File Formats:**
+    - Legacy eval format: `[eval]` section
+    - Prompt learning format: `[prompt_learning]` section
+
+    **Precedence Order:**
+    1. CLI arguments (highest priority)
+    2. Config file values
+    3. Default values
+
+    Args:
+        config_path: Path to TOML config file (optional)
+        cli_app_id: App ID from CLI (overrides config)
+        cli_model: Model name from CLI (overrides config)
+        cli_seeds: Seeds list from CLI (overrides config)
+        cli_url: Task app URL from CLI (overrides config)
+        cli_backend_url: Backend URL from CLI (overrides config)
+        cli_concurrency: Concurrency from CLI (overrides config)
+        seed_set: Which seed pool to use ("seeds", "validation_seeds", "test_pool")
+        metadata: Metadata key-value pairs for filtering
+
+    Returns:
+        Resolved EvalRunConfig with all values merged.
+
+    Raises:
+        FileNotFoundError: If config file is specified but doesn't exist.
+
+    Example:
+        ```python
+        config = resolve_eval_config(
+            config_path=Path("banking77_eval.toml"),
+            cli_app_id="banking77",
+            cli_seeds=[0, 1, 2],
+            cli_url="http://localhost:8103",
+            seed_set="seeds",
+            metadata={},
+        )
+        ```
+    """
+    raw: dict[str, Any] = {}
+    if config_path is not None:
+        raw = load_eval_toml(config_path)
+
+    if raw and ("prompt_learning" in raw or raw.get("algorithm") in {"gepa", "mipro"}):
+        resolved = _from_prompt_learning(raw, seed_set=seed_set)
+    else:
+        resolved = _from_legacy_eval(raw)
+
+    if cli_app_id:
+        resolved.app_id = cli_app_id
+    if cli_url:
+        resolved.task_app_url = cli_url
+    if cli_seeds:
+        resolved.seeds = cli_seeds
+    if cli_ops:
+        resolved.ops = cli_ops
+    if cli_return_trace is not None:
+        resolved.return_trace = cli_return_trace
+    if cli_concurrency is not None:
+        resolved.concurrency = cli_concurrency
+    if cli_output_txt is not None:
+        resolved.output_txt = cli_output_txt
+    if cli_output_json is not None:
+        resolved.output_json = cli_output_json
+    if cli_backend_url:
+        resolved.backend_url = cli_backend_url
+    if cli_wait:
+        resolved.wait = True
+    if cli_poll_interval is not None:
+        resolved.poll_interval = cli_poll_interval
+    if cli_traces_dir is not None:
+        resolved.traces_dir = cli_traces_dir
+
+    if cli_model:
+        resolved.policy_config["model"] = cli_model
+    if metadata:
+        resolved.metadata = metadata
+
+    if cli_env_file:
+        # Store in metadata for logging; env loading handled in core.
+        resolved.metadata.setdefault("env_file", cli_env_file)
+
+    resolved.config_path = config_path
+
+    return resolved
+
+
+__all__ = ["EvalRunConfig", "resolve_eval_config", "SeedSet"]
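For reference, a usage sketch of the module added above: resolve a legacy `[eval]` TOML and let CLI-style values take precedence. The config file name and override values are illustrative only; every keyword is passed explicitly because `resolve_eval_config` declares no defaults.

```python
# Usage sketch for synth_ai/cli/commands/eval/config.py (added in 0.4.4).
# "banking77_eval.toml" and the override values below are hypothetical.
from pathlib import Path

from synth_ai.cli.commands.eval.config import resolve_eval_config

config = resolve_eval_config(
    config_path=Path("banking77_eval.toml"),  # must exist, else FileNotFoundError
    cli_app_id=None,                 # keep app_id from the TOML
    cli_model="gpt-4o-mini",         # overrides [eval].model / policy_config.model
    cli_seeds=[0, 1, 2],             # overrides [eval].seeds
    cli_url=None,
    cli_env_file=None,
    cli_ops=None,
    cli_return_trace=True,
    cli_concurrency=4,
    cli_output_txt=None,
    cli_output_json=None,
    cli_backend_url=None,
    cli_wait=False,
    cli_poll_interval=None,
    cli_traces_dir=None,
    seed_set="seeds",
    metadata={"run": "smoke"},
)

# CLI values win over the TOML, per the merge logic in resolve_eval_config.
assert config.policy_config["model"] == "gpt-4o-mini"
assert config.seeds == [0, 1, 2]
```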