synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release: this version of synth-ai might be problematic.
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
- examples/rl/configs/rl_from_base_qwen17.toml +1 -0
- examples/swe/task_app/hosted/inference/openai_client.py +0 -34
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/task_app.py +254 -36
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
- synth_ai/api/train/builders.py +90 -1
- synth_ai/api/train/cli.py +396 -21
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +15 -1
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +29 -0
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +85 -17
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +1 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/eval/core.py +13 -10
- synth_ai/cli/commands/filter/core.py +53 -17
- synth_ai/cli/commands/help/core.py +0 -1
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/train/judge_schemas.py +1 -0
- synth_ai/cli/commands/train/judge_validation.py +1 -0
- synth_ai/cli/commands/train/validation.py +0 -57
- synth_ai/cli/demo.py +35 -3
- synth_ai/cli/deploy/__init__.py +40 -25
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/task_app_deploy.py +1 -1
- synth_ai/cli/task_apps.py +53 -53
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/judge_schemas.py +1 -0
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/handlers.py +53 -4
- synth_ai/streaming/streamer.py +19 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +44 -8
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +17 -17
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +283 -1
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
- synth_ai/cli/commands/deploy/__init__.py +0 -23
- synth_ai/cli/commands/deploy/core.py +0 -614
- synth_ai/cli/commands/deploy/errors.py +0 -72
- synth_ai/cli/commands/deploy/validation.py +0 -11
- synth_ai/cli/deploy/core.py +0 -5
- synth_ai/cli/deploy/errors.py +0 -23
- synth_ai/cli/deploy/validation.py +0 -5
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
synth_ai/api/train/validators.py
@@ -0,0 +1,277 @@
+"""SDK-side validation for training configs - catch errors BEFORE sending to backend."""
+
+from pathlib import Path
+from typing import Any
+
+import click
+
+
+class ConfigValidationError(Exception):
+    """Raised when a training config is invalid."""
+    pass
+
+
+def validate_prompt_learning_config(config_data: dict[str, Any], config_path: Path) -> None:
+    """
+    Validate prompt learning config BEFORE sending to backend.
+
+    This catches common errors early with clear messages instead of cryptic backend errors.
+
+    Args:
+        config_data: Parsed TOML/JSON config
+        config_path: Path to config file (for error messages)
+
+    Raises:
+        ConfigValidationError: If config is invalid
+        click.ClickException: If validation fails (for CLI)
+    """
+    errors: list[str] = []
+
+    # Check for prompt_learning section
+    pl_section = config_data.get("prompt_learning")
+    if not pl_section:
+        errors.append(
+            "Missing [prompt_learning] section in config. "
+            "Expected: [prompt_learning] with algorithm, task_app_url, etc."
+        )
+        _raise_validation_errors(errors, config_path)
+        return
+
+    if not isinstance(pl_section, dict):
+        errors.append(
+            f"[prompt_learning] must be a table/dict, got {type(pl_section).__name__}"
+        )
+        _raise_validation_errors(errors, config_path)
+        return
+
+    # CRITICAL: Validate algorithm field
+    algorithm = pl_section.get("algorithm")
+    if not algorithm:
+        errors.append(
+            "Missing required field: prompt_learning.algorithm\n"
+            "  Must be one of: 'gepa', 'mipro'\n"
+            "  Example:\n"
+            "    [prompt_learning]\n"
+            "    algorithm = \"gepa\""
+        )
+    elif algorithm not in ("gepa", "mipro"):
+        errors.append(
+            f"Invalid algorithm: '{algorithm}'\n"
+            f"  Must be one of: 'gepa', 'mipro' (Note: MIPRO not yet implemented)\n"
+            f"  Got: '{algorithm}'"
+        )
+
+    # Validate task_app_url
+    task_app_url = pl_section.get("task_app_url")
+    if not task_app_url:
+        errors.append(
+            "Missing required field: prompt_learning.task_app_url\n"
+            "  Example:\n"
+            "    task_app_url = \"http://127.0.0.1:8102\""
+        )
+    elif not isinstance(task_app_url, str):
+        errors.append(
+            f"task_app_url must be a string, got {type(task_app_url).__name__}"
+        )
+    elif not task_app_url.startswith(("http://", "https://")):
+        errors.append(
+            f"task_app_url must start with http:// or https://, got: '{task_app_url}'"
+        )
+
+    # Validate initial_prompt if present
+    initial_prompt = pl_section.get("initial_prompt")
+    if initial_prompt:
+        if not isinstance(initial_prompt, dict):
+            errors.append(
+                f"prompt_learning.initial_prompt must be a table/dict, got {type(initial_prompt).__name__}"
+            )
+        else:
+            # Validate messages array
+            messages = initial_prompt.get("messages")
+            if messages is not None:
+                if not isinstance(messages, list):
+                    errors.append(
+                        f"prompt_learning.initial_prompt.messages must be an array, got {type(messages).__name__}"
+                    )
+                elif len(messages) == 0:
+                    errors.append(
+                        "prompt_learning.initial_prompt.messages is empty (must have at least one message)"
+                    )
+
+    # Validate policy config
+    policy = pl_section.get("policy")
+    if not policy or not isinstance(policy, dict):
+        errors.append("Missing [prompt_learning.policy] section or not a table")
+    else:
+        # Enforce inference_mode
+        mode = str(policy.get("inference_mode", "")).strip().lower()
+        if not mode:
+            errors.append("Missing required field: prompt_learning.policy.inference_mode (must be 'synth_hosted')")
+        elif mode != "synth_hosted":
+            errors.append("prompt_learning.policy.inference_mode must be 'synth_hosted' (bring_your_own unsupported)")
+        # Required fields for synth_hosted
+        provider = (policy.get("provider") or "").strip().lower()
+        model = (policy.get("model") or "").strip()
+        inference_url = (policy.get("inference_url") or "").strip()
+        if not provider:
+            errors.append("Missing required field: prompt_learning.policy.provider")
+        if not model:
+            errors.append("Missing required field: prompt_learning.policy.model")
+        if not inference_url:
+            errors.append("Missing required field: prompt_learning.policy.inference_url")
+        elif not isinstance(inference_url, str) or not inference_url.startswith(("http://", "https://")):
+            errors.append(f"policy.inference_url must start with http:// or https://, got: '{inference_url}'")
+
+    # Validate algorithm-specific config
+    if algorithm == "gepa":
+        gepa_config = pl_section.get("gepa")
+        if not gepa_config or not isinstance(gepa_config, dict):
+            errors.append("Missing [prompt_learning.gepa] section for GEPA algorithm")
+        else:
+            # Numeric sanity checks
+            def _pos_int(name: str) -> None:
+                val = gepa_config.get(name)
+                if val is not None:
+                    try:
+                        ival = int(val)
+                        if ival <= 0:
+                            errors.append(f"prompt_learning.gepa.{name} must be > 0")
+                    except Exception:
+                        errors.append(f"prompt_learning.gepa.{name} must be an integer")
+            for fld in ("initial_population_size", "num_generations", "children_per_generation", "max_concurrent_rollouts"):
+                _pos_int(fld)
+            # Budget cap
+            if "max_spend_usd" in gepa_config and gepa_config.get("max_spend_usd") is not None:
+                try:
+                    f = float(gepa_config.get("max_spend_usd"))
+                    if f <= 0:
+                        errors.append("prompt_learning.gepa.max_spend_usd must be > 0 when provided")
+                except Exception:
+                    errors.append("prompt_learning.gepa.max_spend_usd must be numeric")
+
+    elif algorithm == "mipro":
+        # MIPRO is not yet implemented in synth-ai
+        errors.append(
+            "MIPRO algorithm is not yet implemented in synth-ai.\n"
+            "  Please use 'gepa' algorithm for prompt optimization.\n"
+            "  MIPRO support is planned for a future release.\n"
+            "  Example:\n"
+            "    [prompt_learning]\n"
+            "    algorithm = \"gepa\"\n"
+            "    [prompt_learning.gepa]\n"
+            "    # ... gepa configuration"
+        )
+
+    # Raise all errors at once for better UX
+    if errors:
+        _raise_validation_errors(errors, config_path)
+
+
+def _raise_validation_errors(errors: list[str], config_path: Path) -> None:
+    """Format and raise validation errors."""
+    error_msg = (
+        f"\n❌ Invalid prompt learning config: {config_path}\n\n"
+        f"Found {len(errors)} error(s):\n\n"
+    )
+
+    for i, error in enumerate(errors, 1):
+        # Indent multi-line errors
+        indented_error = "\n   ".join(error.split("\n"))
+        error_msg += f"{i}. {indented_error}\n\n"
+
+    error_msg += (
+        "📖 See example configs:\n"
+        "  - examples/blog_posts/gepa/configs/banking77_gepa_local.toml\n"
+        "  - examples/blog_posts/mipro/configs/banking77_mipro_local.toml\n"
+    )
+
+    raise click.ClickException(error_msg)
+
+
+def validate_rl_config(config_data: dict[str, Any], config_path: Path) -> None:
+    """
+    Validate RL config BEFORE sending to backend.
+
+    Args:
+        config_data: Parsed TOML/JSON config
+        config_path: Path to config file (for error messages)
+
+    Raises:
+        ConfigValidationError: If config is invalid
+        click.ClickException: If validation fails (for CLI)
+    """
+    errors: list[str] = []
+
+    # Check for rl section
+    rl_section = config_data.get("rl") or config_data.get("online_rl")
+    if not rl_section:
+        errors.append(
+            "Missing [rl] or [online_rl] section in config"
+        )
+        _raise_validation_errors(errors, config_path)
+        return
+
+    # Validate algorithm
+    algorithm = rl_section.get("algorithm")
+    if not algorithm:
+        errors.append(
+            "Missing required field: rl.algorithm\n"
+            "  Must be one of: 'grpo', 'ppo', etc."
+        )
+
+    # Validate task_url
+    task_url = rl_section.get("task_url")
+    if not task_url:
+        errors.append(
+            "Missing required field: rl.task_url"
+        )
+    elif not isinstance(task_url, str):
+        errors.append(
+            f"task_url must be a string, got {type(task_url).__name__}"
+        )
+
+    if errors:
+        _raise_validation_errors(errors, config_path)
+
+
+def validate_sft_config(config_data: dict[str, Any], config_path: Path) -> None:
+    """
+    Validate SFT config BEFORE sending to backend.
+
+    Args:
+        config_data: Parsed TOML/JSON config
+        config_path: Path to config file (for error messages)
+
+    Raises:
+        ConfigValidationError: If config is invalid
+        click.ClickException: If validation fails (for CLI)
+    """
+    errors: list[str] = []
+
+    # Check for sft section
+    sft_section = config_data.get("sft")
+    if not sft_section:
+        errors.append(
+            "Missing [sft] section in config"
+        )
+        _raise_validation_errors(errors, config_path)
+        return
+
+    # Validate model
+    model = sft_section.get("model")
+    if not model:
+        errors.append(
+            "Missing required field: sft.model"
+        )
+
+    if errors:
+        _raise_validation_errors(errors, config_path)
+
+
+__all__ = [
+    "ConfigValidationError",
+    "validate_prompt_learning_config",
+    "validate_rl_config",
+    "validate_sft_config",
+]
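For orientation, here is a minimal config sketch that these checks accept, assuming Python 3.11+ for `tomllib`; the provider, model, and URL values are illustrative placeholders, not taken from the shipped example configs:

```python
# Hypothetical usage of the new validator; field names mirror the checks
# above, while the provider/model/URL values are illustrative only.
import tomllib
from pathlib import Path

from synth_ai.api.train.validators import validate_prompt_learning_config

CONFIG_TOML = """
[prompt_learning]
algorithm = "gepa"
task_app_url = "http://127.0.0.1:8102"

[prompt_learning.policy]
inference_mode = "synth_hosted"
provider = "openai"
model = "gpt-4o-mini"
inference_url = "https://api.openai.com/v1"

[prompt_learning.gepa]
initial_population_size = 8
num_generations = 4
children_per_generation = 4
max_concurrent_rollouts = 4
max_spend_usd = 5.0
"""

config = tomllib.loads(CONFIG_TOML)
# Passes silently when valid; raises click.ClickException listing every
# problem at once when invalid.
validate_prompt_learning_config(config, Path("example_gepa.toml"))
```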
synth_ai/baseline/__init__.py
@@ -0,0 +1,25 @@
+"""Baseline file system for self-contained task evaluation.
+
+This package provides abstractions for defining and executing baseline evaluations
+without requiring deployed task apps. Supports both class-based and function-based
+task runners with first-class train/val/test split support.
+"""
+
+from __future__ import annotations
+
+from synth_ai.baseline.config import (
+    BaselineConfig,
+    BaselineResults,
+    BaselineTaskRunner,
+    DataSplit,
+    TaskResult,
+)
+
+__all__ = [
+    "BaselineConfig",
+    "BaselineTaskRunner",
+    "DataSplit",
+    "TaskResult",
+    "BaselineResults",
+]
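A quick sketch of the class-based style the docstring mentions; `EchoRunner` and its stubbed episode logic are hypothetical stand-ins for a real task:

```python
# Class-based runner sketch: subclass BaselineTaskRunner and implement
# run_task. The body is a stub standing in for real environment logic.
from synth_ai.baseline import BaselineTaskRunner, TaskResult


class EchoRunner(BaselineTaskRunner):
    async def run_task(self, seed: int) -> TaskResult:
        # env_config/policy_config are populated from the BaselineConfig
        # defaults (or CLI overrides) when the runner is instantiated.
        steps = int(self.env_config.get("max_steps", 1))
        return TaskResult(
            seed=seed,
            success=True,
            outcome_reward=0.0,
            total_steps=steps,
            metadata={"model": self.policy_config.get("model")},
        )
```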
synth_ai/baseline/config.py
@@ -0,0 +1,209 @@
+"""Core dataclasses for baseline configuration and results."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+
+
+class BaselineTaskRunner:
+    """
+    Base class for task runners.
+
+    Subclasses should implement `run_task` method for class-based approach,
+    or you can use standalone async functions for function-based approach.
+    """
+
+    def __init__(
+        self,
+        policy_config: Dict[str, Any],
+        env_config: Dict[str, Any],
+    ):
+        """
+        Initialize task runner with configuration.
+
+        Args:
+            policy_config: Policy configuration (model, temperature, etc.)
+            env_config: Environment configuration (max_steps, difficulty, etc.)
+        """
+        self.policy_config = policy_config
+        self.env_config = env_config
+
+    async def run_task(self, seed: int) -> TaskResult:
+        """
+        Execute a single task instance.
+
+        This method is called for each seed in the selected split.
+
+        Args:
+            seed: The seed/index for this task instance
+
+        Returns:
+            TaskResult: Structured result containing success, rewards, metadata, trace
+        """
+        raise NotImplementedError("Subclasses must implement run_task method")
+
+
+@dataclass
+class DataSplit:
+    """Definition of a data split (train/val/test)."""
+
+    name: str  # "train", "val", "test"
+    seeds: List[int]  # Seed/index values for this split
+    metadata: Dict[str, Any] = field(default_factory=dict)  # Optional metadata
+
+
+@dataclass
+class TaskResult:
+    """Result from a single task execution."""
+
+    # Required: Seed/index that was evaluated
+    seed: int
+
+    # Required: Did the task complete successfully?
+    success: bool
+
+    # Required: Outcome reward for the episode
+    outcome_reward: float
+
+    # Optional: Event rewards (step-level)
+    event_rewards: List[Dict[str, Any]] = field(default_factory=list)
+
+    # Optional: Total steps/turns taken
+    total_steps: int = 0
+
+    # Optional: Metadata (achievements, completion info, etc.)
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    # Optional: Error information if success=False
+    error: Optional[str] = None
+
+    # Optional: v3 trace (SessionTrace dict)
+    trace: Optional[Dict[str, Any]] = None
+
+
+# Type alias for task runner (can be class or function)
+TaskRunnerType = (
+    type[BaselineTaskRunner]
+    | Callable[[int, dict[str, Any], dict[str, Any]], Any]  # Function signature
+)
+
+# Type alias for result aggregator (can be class or function)
+AggregatorType = (
+    type[Any]  # Class with aggregate() method
+    | Callable[[list[TaskResult]], dict[str, Any]]  # Function signature
+)
+
+
+@dataclass
+class BaselineConfig:
+    """Configuration for a baseline file.
+
+    A baseline file defines how to evaluate a task without requiring
+    a deployed task app. It provides self-contained evaluation logic
+    with first-class support for train/val/test splits.
+
+    Supports both class-based and function-based task runners:
+    - Class-based: Pass a class that inherits from BaselineTaskRunner
+    - Function-based: Pass an async function with signature:
+        async def task_runner(seed: int, policy_config: Dict[str, Any],
+                              env_config: Dict[str, Any]) -> TaskResult
+    """
+
+    # Required: Unique identifier for this baseline config
+    baseline_id: str
+
+    # Required: Human-readable name
+    name: str
+
+    # Required: Task runner (class or function)
+    # Class-based: Pass a class inheriting from BaselineTaskRunner
+    # The class will be instantiated with policy_config and env_config,
+    # and run_task(seed) will be called for each seed.
+    # Function-based: Pass an async function with signature:
+    #   async def task_runner(seed: int, policy_config: Dict[str, Any],
+    #                         env_config: Dict[str, Any]) -> TaskResult
+    task_runner: TaskRunnerType
+
+    # Required: Data splits (train/val/test)
+    splits: Dict[str, DataSplit]
+
+    # Optional: Description for documentation
+    description: str = ""
+
+    # Optional: Default policy configuration
+    default_policy_config: Dict[str, Any] = field(default_factory=dict)
+
+    # Optional: Default environment configuration
+    default_env_config: Dict[str, Any] = field(default_factory=dict)
+
+    # Optional: Metadata for filtering/organization
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    # Optional: Tags for filtering and discovery
+    tags: List[str] = field(default_factory=list)
+
+    # Optional: Custom result aggregator (class or function)
+    # Class-based: Pass a class with aggregate(results: List[TaskResult]) method
+    # The class will be instantiated and aggregate() called.
+    # Function-based: Pass a function with signature:
+    #   def aggregate_results(results: List[TaskResult]) -> Dict[str, Any]
+    result_aggregator: Optional[AggregatorType] = None
+
+    # Optional: Path to this baseline file (set by discovery)
+    _source_path: Optional[Path] = None
+
+    def matches_tag(self, tag: str) -> bool:
+        """Check if baseline matches a tag (case-insensitive)."""
+        return tag.lower() in [t.lower() for t in self.tags]
+
+    def matches_metadata(self, key: str, value: Any) -> bool:
+        """Check if baseline metadata matches key-value pair."""
+        return self.metadata.get(key) == value
+
+
+@dataclass
+class BaselineResults:
+    """Aggregate results from a baseline evaluation."""
+
+    # Configuration that was used
+    config: BaselineConfig
+
+    # Split that was evaluated
+    split_name: str
+
+    # Per-seed results
+    results: List[TaskResult]
+
+    # Aggregate metrics
+    aggregate_metrics: Dict[str, Any]
+
+    # Execution metadata
+    execution_time_seconds: float
+    model_name: str
+    timestamp: str
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize to dictionary for JSON output."""
+        return {
+            "baseline_id": self.config.baseline_id,
+            "name": self.config.name,
+            "split": self.split_name,
+            "model": self.model_name,
+            "timestamp": self.timestamp,
+            "execution_time_seconds": self.execution_time_seconds,
+            "aggregate_metrics": self.aggregate_metrics,
+            "results": [
+                {
+                    "seed": r.seed,
+                    "success": r.success,
+                    "outcome_reward": r.outcome_reward,
+                    "total_steps": r.total_steps,
+                    "metadata": r.metadata,
+                    "error": r.error,
+                }
+                for r in self.results
+            ],
+        }
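And the function-based style, wiring an async runner and splits into a `BaselineConfig`; the names, seed values, and configs here are placeholders:

```python
# Function-based baseline sketch: an async runner plus train/val splits.
from typing import Any, Dict

from synth_ai.baseline import BaselineConfig, DataSplit, TaskResult


async def toy_runner(
    seed: int, policy_config: Dict[str, Any], env_config: Dict[str, Any]
) -> TaskResult:
    # A real runner would roll out an episode here; we just echo the seed.
    return TaskResult(seed=seed, success=True, outcome_reward=1.0, total_steps=1)


BASELINE = BaselineConfig(
    baseline_id="toy",
    name="Toy baseline",
    task_runner=toy_runner,
    splits={
        "train": DataSplit(name="train", seeds=list(range(8))),
        "val": DataSplit(name="val", seeds=[100, 101]),
    },
    default_policy_config={"model": "gpt-4o-mini"},
    tags=["demo"],
)
```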