synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff shows the changes between package versions as published to their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of synth-ai might be problematic.
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
- examples/rl/configs/rl_from_base_qwen17.toml +1 -0
- examples/swe/task_app/hosted/inference/openai_client.py +0 -34
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/task_app.py +254 -36
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
- synth_ai/api/train/builders.py +90 -1
- synth_ai/api/train/cli.py +396 -21
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +15 -1
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +29 -0
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +85 -17
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +1 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/eval/core.py +13 -10
- synth_ai/cli/commands/filter/core.py +53 -17
- synth_ai/cli/commands/help/core.py +0 -1
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/train/judge_schemas.py +1 -0
- synth_ai/cli/commands/train/judge_validation.py +1 -0
- synth_ai/cli/commands/train/validation.py +0 -57
- synth_ai/cli/demo.py +35 -3
- synth_ai/cli/deploy/__init__.py +40 -25
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/task_app_deploy.py +1 -1
- synth_ai/cli/task_apps.py +53 -53
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/judge_schemas.py +1 -0
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/handlers.py +53 -4
- synth_ai/streaming/streamer.py +19 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +44 -8
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +17 -17
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +283 -1
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
- synth_ai/cli/commands/deploy/__init__.py +0 -23
- synth_ai/cli/commands/deploy/core.py +0 -614
- synth_ai/cli/commands/deploy/errors.py +0 -72
- synth_ai/cli/commands/deploy/validation.py +0 -11
- synth_ai/cli/deploy/core.py +0 -5
- synth_ai/cli/deploy/errors.py +0 -23
- synth_ai/cli/deploy/validation.py +0 -5
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
synth_ai/api/train/configs/prompt_learning.py
ADDED
@@ -0,0 +1,442 @@
+"""Prompt Learning configuration models for MIPRO and GEPA."""
+from __future__ import annotations
+
+from collections.abc import Mapping
+from enum import Enum
+from pathlib import Path
+from typing import Any
+
+from pydantic import Field, field_validator
+
+from ..utils import load_toml
+from .shared import ExtraModel
+
+
+class InferenceMode(str, Enum):
+    synth_hosted = "synth_hosted"
+
+
+class ProviderName(str, Enum):
+    openai = "openai"
+    groq = "groq"
+    google = "google"
+
+
+class PromptLearningPolicyConfig(ExtraModel):
+    """Policy configuration for prompt learning (model, provider, etc.)."""
+    model: str
+    provider: ProviderName
+    inference_url: str
+    inference_mode: InferenceMode = InferenceMode.synth_hosted
+    temperature: float = 0.0
+    max_completion_tokens: int = 512
+    policy_name: str | None = None
+
+    @field_validator("inference_url")
+    @classmethod
+    def _normalize_inference_url(cls, v: str) -> str:
+        if not isinstance(v, str):
+            raise ValueError("inference_url must be a string")
+        v = v.strip()
+        if not v.startswith(("http://", "https://")):
+            raise ValueError("inference_url must start with http:// or https://")
+        return v
+
+
+class MessagePatternConfig(ExtraModel):
+    """Configuration for a single message pattern."""
+    role: str
+    pattern: str
+    order: int = 0
+
+
+class PromptPatternConfig(ExtraModel):
+    """Initial prompt pattern configuration."""
+    id: str | None = None
+    name: str | None = None
+    messages: list[MessagePatternConfig] = []
+    wildcards: dict[str, str] = Field(default_factory=dict)
+
+
+class MIPROConfig(ExtraModel):
+    """MIPRO-specific configuration.
+
+    NOTE: MIPRO support is not yet implemented in synth-ai.
+    This configuration class exists for future compatibility.
+    Use GEPA algorithm for prompt optimization.
+    """
+    num_iterations: int = 20
+    num_evaluations_per_iteration: int = 5
+    batch_size: int = 32
+    max_concurrent: int = 20
+    env_name: str = "banking77"
+    env_config: dict[str, Any] | None = None
+    meta_model: str = "gpt-4o-mini"
+    meta_model_provider: str = "openai"
+    meta_model_inference_url: str | None = None
+    few_shot_score_threshold: float = 0.8
+    results_file: str | None = None
+    max_wall_clock_seconds: float | None = None
+    max_total_tokens: int | None = None
+
+    # TPE configuration
+    tpe: dict[str, Any] | None = None
+
+    # Demo configuration
+    demo: dict[str, Any] | None = None
+
+    # Grounding configuration
+    grounding: dict[str, Any] | None = None
+
+    # Meta-update configuration
+    meta_update: dict[str, Any] | None = None
+
+    # Bootstrap seeds (for few-shot examples)
+    bootstrap_train_seeds: list[int] | None = None
+
+    # Online pool (for mini-batch evaluation)
+    online_pool: list[int] | None = None
+
+    # Test pool (held-out seeds)
+    test_pool: list[int] | None = None
+
+
+# GEPA nested configs (mirroring RL structure)
+class GEPARolloutConfig(ExtraModel):
+    """GEPA rollout configuration (mirrors RL [rollout] section)."""
+    budget: int | None = None  # Total rollout budget
+    max_concurrent: int = 20  # Maximum concurrent rollouts
+    minibatch_size: int = 8  # Minibatch size for evaluation
+
+
+class GEPAEvaluationConfig(ExtraModel):
+    """GEPA evaluation configuration (mirrors RL [evaluation] section)."""
+    seeds: list[int] | None = None  # Evaluation seeds (training set)
+    validation_seeds: list[int] | None = None  # Validation seeds (held-out)
+    test_pool: list[int] | None = None  # Test pool (final evaluation)
+    validation_pool: str | None = None  # Pool name for validation (e.g., "validation")
+    validation_top_k: int | None = None  # Top-K prompts to validate
+
+
+class GEPAMutationConfig(ExtraModel):
+    """GEPA mutation configuration (LLM-guided mutation settings)."""
+    rate: float = 0.3  # Mutation rate
+    llm_model: str | None = None  # Model for generating mutations
+    llm_provider: str = "groq"  # Provider for mutation LLM
+    llm_inference_url: str | None = None  # Custom inference URL
+    prompt: str | None = None  # Custom mutation prompt
+
+
+class GEPAPopulationConfig(ExtraModel):
+    """GEPA population configuration (evolution parameters)."""
+    initial_size: int = 20  # Initial population size
+    num_generations: int = 10  # Number of generations
+    children_per_generation: int = 5  # Children generated per generation
+    crossover_rate: float = 0.5  # Crossover rate
+    selection_pressure: float = 1.0  # Pareto selection pressure
+    patience_generations: int = 3  # Early stopping patience
+
+
+class GEPAArchiveConfig(ExtraModel):
+    """GEPA archive configuration (Pareto archive settings)."""
+    size: int = 64  # Archive size
+    pareto_set_size: int = 64  # Pareto set size
+    pareto_eps: float = 1e-6  # Pareto epsilon
+    feedback_fraction: float = 0.5  # Fraction of archive for feedback
+
+
+class GEPATokenConfig(ExtraModel):
+    """GEPA token and budget configuration."""
+    max_limit: int | None = None  # Maximum tokens allowed in prompt
+    counting_model: str = "gpt-4"  # Model for token counting
+    enforce_pattern_limit: bool = True  # Enforce token limit on patterns
+    max_spend_usd: float | None = None  # Maximum spend in USD
+
+
+class GEPAConfig(ExtraModel):
+    """GEPA-specific configuration with nested subsections."""
+    # Top-level fields (for backwards compatibility)
+    env_name: str = "banking77"
+    env_config: dict[str, Any] | None = None
+    rng_seed: int | None = None
+    proposer_type: str = "dspy"  # "dspy" or "synth"
+
+    # Nested subsections (preferred, mirrors RL structure)
+    rollout: GEPARolloutConfig | None = None
+    evaluation: GEPAEvaluationConfig | None = None
+    mutation: GEPAMutationConfig | None = None
+    population: GEPAPopulationConfig | None = None
+    archive: GEPAArchiveConfig | None = None
+    token: GEPATokenConfig | None = None
+
+    # Backwards compatibility: flat fields (deprecated, prefer nested)
+    # These will be flattened from nested configs if provided
+    rollout_budget: int | None = None
+    max_concurrent_rollouts: int | None = None
+    minibatch_size: int | None = None
+    evaluation_seeds: list[int] | None = None
+    validation_seeds: list[int] | None = None
+    test_pool: list[int] | None = None
+    validation_pool: str | None = None
+    validation_top_k: int | None = None
+    mutation_rate: float | None = None
+    mutation_llm_model: str | None = None
+    mutation_llm_provider: str | None = None
+    mutation_llm_inference_url: str | None = None
+    mutation_prompt: str | None = None
+    initial_population_size: int | None = None
+    num_generations: int | None = None
+    children_per_generation: int | None = None
+    crossover_rate: float | None = None
+    selection_pressure: float | None = None
+    patience_generations: int | None = None
+    archive_size: int | None = None
+    pareto_set_size: int | None = None
+    pareto_eps: float | None = None
+    feedback_fraction: float | None = None
+    max_token_limit: int | None = None
+    token_counting_model: str | None = None
+    enforce_pattern_token_limit: bool | None = None
+    max_spend_usd: float | None = None
+
+    def _get_rollout_budget(self) -> int | None:
+        """Get rollout budget from nested or flat structure."""
+        if self.rollout and self.rollout.budget is not None:
+            return self.rollout.budget
+        return self.rollout_budget
+
+    def _get_max_concurrent_rollouts(self) -> int:
+        """Get max concurrent rollouts from nested or flat structure."""
+        if self.rollout and self.rollout.max_concurrent is not None:
+            return self.rollout.max_concurrent
+        return self.max_concurrent_rollouts or 20
+
+    def _get_minibatch_size(self) -> int:
+        """Get minibatch size from nested or flat structure."""
+        if self.rollout and self.rollout.minibatch_size is not None:
+            return self.rollout.minibatch_size
+        return self.minibatch_size or 8
+
+    def _get_evaluation_seeds(self) -> list[int] | None:
+        """Get evaluation seeds from nested or flat structure."""
+        if self.evaluation and self.evaluation.seeds is not None:
+            return self.evaluation.seeds
+        return self.evaluation_seeds
+
+    def _get_validation_seeds(self) -> list[int] | None:
+        """Get validation seeds from nested or flat structure."""
+        if self.evaluation and self.evaluation.validation_seeds is not None:
+            return self.evaluation.validation_seeds
+        return self.validation_seeds
+
+    def _get_test_pool(self) -> list[int] | None:
+        """Get test pool from nested or flat structure."""
+        if self.evaluation and self.evaluation.test_pool is not None:
+            return self.evaluation.test_pool
+        return self.test_pool
+
+    def _get_mutation_rate(self) -> float:
+        """Get mutation rate from nested or flat structure."""
+        if self.mutation and self.mutation.rate is not None:
+            return self.mutation.rate
+        return self.mutation_rate or 0.3
+
+    def _get_mutation_llm_model(self) -> str | None:
+        """Get mutation LLM model from nested or flat structure."""
+        if self.mutation and self.mutation.llm_model is not None:
+            return self.mutation.llm_model
+        return self.mutation_llm_model
+
+    def _get_mutation_llm_provider(self) -> str:
+        """Get mutation LLM provider from nested or flat structure."""
+        if self.mutation and self.mutation.llm_provider is not None:
+            return self.mutation.llm_provider
+        return self.mutation_llm_provider or "groq"
+
+    def _get_mutation_llm_inference_url(self) -> str | None:
+        """Get mutation LLM inference URL from nested or flat structure."""
+        if self.mutation and self.mutation.llm_inference_url is not None:
+            return self.mutation.llm_inference_url
+        return self.mutation_llm_inference_url
+
+    def _get_mutation_prompt(self) -> str | None:
+        """Get mutation prompt from nested or flat structure."""
+        if self.mutation and self.mutation.prompt is not None:
+            return self.mutation.prompt
+        return self.mutation_prompt
+
+    def _get_initial_population_size(self) -> int:
+        """Get initial population size from nested or flat structure."""
+        if self.population and self.population.initial_size is not None:
+            return self.population.initial_size
+        return self.initial_population_size or 20
+
+    def _get_num_generations(self) -> int:
+        """Get num generations from nested or flat structure."""
+        if self.population and self.population.num_generations is not None:
+            return self.population.num_generations
+        return self.num_generations or 10
+
+    def _get_children_per_generation(self) -> int:
+        """Get children per generation from nested or flat structure."""
+        if self.population and self.population.children_per_generation is not None:
+            return self.population.children_per_generation
+        return self.children_per_generation or 5
+
+    def _get_crossover_rate(self) -> float:
+        """Get crossover rate from nested or flat structure."""
+        if self.population and self.population.crossover_rate is not None:
+            return self.population.crossover_rate
+        return self.crossover_rate or 0.5
+
+    def _get_selection_pressure(self) -> float:
+        """Get selection pressure from nested or flat structure."""
+        if self.population and self.population.selection_pressure is not None:
+            return self.population.selection_pressure
+        return self.selection_pressure or 1.0
+
+    def _get_patience_generations(self) -> int:
+        """Get patience generations from nested or flat structure."""
+        if self.population and self.population.patience_generations is not None:
+            return self.population.patience_generations
+        return self.patience_generations or 3
+
+    def _get_archive_size(self) -> int:
+        """Get archive size from nested or flat structure."""
+        if self.archive and self.archive.size is not None:
+            return self.archive.size
+        return self.archive_size or 64
+
+    def _get_pareto_set_size(self) -> int:
+        """Get pareto set size from nested or flat structure."""
+        if self.archive and self.archive.pareto_set_size is not None:
+            return self.archive.pareto_set_size
+        return self.pareto_set_size or 64
+
+    def _get_pareto_eps(self) -> float:
+        """Get pareto eps from nested or flat structure."""
+        if self.archive and self.archive.pareto_eps is not None:
+            return self.archive.pareto_eps
+        return self.pareto_eps or 1e-6
+
+    def _get_feedback_fraction(self) -> float:
+        """Get feedback fraction from nested or flat structure."""
+        if self.archive and self.archive.feedback_fraction is not None:
+            return self.archive.feedback_fraction
+        return self.feedback_fraction or 0.5
+
+    def _get_max_token_limit(self) -> int | None:
+        """Get max token limit from nested or flat structure."""
+        if self.token and self.token.max_limit is not None:
+            return self.token.max_limit
+        return self.max_token_limit
+
+    def _get_token_counting_model(self) -> str:
+        """Get token counting model from nested or flat structure."""
+        if self.token and self.token.counting_model is not None:
+            return self.token.counting_model
+        return self.token_counting_model or "gpt-4"
+
+    def _get_enforce_pattern_token_limit(self) -> bool:
+        """Get enforce pattern token limit from nested or flat structure."""
+        if self.token and self.token.enforce_pattern_limit is not None:
+            return self.token.enforce_pattern_limit
+        return self.enforce_pattern_token_limit if self.enforce_pattern_token_limit is not None else True
+
+    def _get_max_spend_usd(self) -> float | None:
+        """Get max spend USD from nested or flat structure."""
+        if self.token and self.token.max_spend_usd is not None:
+            return self.token.max_spend_usd
+        return self.max_spend_usd
+
+    @classmethod
+    def from_mapping(cls, data: Mapping[str, Any]) -> GEPAConfig:
+        """Load GEPA config from dict/TOML, handling both nested and flat structures."""
+        # Check for nested structure first
+        nested_data = {}
+        flat_data = {}
+
+        for key, value in data.items():
+            if key in ("rollout", "evaluation", "mutation", "population", "archive", "token"):
+                nested_data[key] = value
+            else:
+                flat_data[key] = value
+
+        # If we have nested data, create nested configs
+        if nested_data:
+            if "rollout" in nested_data:
+                nested_data["rollout"] = GEPARolloutConfig.model_validate(nested_data["rollout"])
+            if "evaluation" in nested_data:
+                nested_data["evaluation"] = GEPAEvaluationConfig.model_validate(nested_data["evaluation"])
+            if "mutation" in nested_data:
+                nested_data["mutation"] = GEPAMutationConfig.model_validate(nested_data["mutation"])
+            if "population" in nested_data:
+                nested_data["population"] = GEPAPopulationConfig.model_validate(nested_data["population"])
+            if "archive" in nested_data:
+                nested_data["archive"] = GEPAArchiveConfig.model_validate(nested_data["archive"])
+            if "token" in nested_data:
+                nested_data["token"] = GEPATokenConfig.model_validate(nested_data["token"])
+
+        # Merge nested and flat data
+        merged_data = {**flat_data, **nested_data}
+        return cls.model_validate(merged_data)
+
+
+class PromptLearningConfig(ExtraModel):
+    """Top-level prompt learning configuration."""
+    algorithm: str  # "mipro" or "gepa"
+    task_app_url: str
+    task_app_api_key: str | None = None
+    task_app_id: str | None = None
+    initial_prompt: PromptPatternConfig | None = None
+    policy: PromptLearningPolicyConfig | None = None
+    mipro: MIPROConfig | None = None
+    gepa: GEPAConfig | None = None
+    env_config: dict[str, Any] | None = None
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert config to dictionary for API payload."""
+        result = self.model_dump(mode="python", exclude_none=True)
+        # Ensure prompt_learning section wraps everything
+        if "prompt_learning" not in result:
+            pl_data = dict(result.items())
+            result = {"prompt_learning": pl_data}
+        return result
+
+    @classmethod
+    def from_mapping(cls, data: Mapping[str, Any]) -> PromptLearningConfig:
+        """Load prompt learning config from dict/TOML mapping."""
+        # Handle both [prompt_learning] section and flat structure
+        pl_data = data.get("prompt_learning", {})
+        if not pl_data:
+            # If no prompt_learning section, assume top-level is prompt_learning
+            pl_data = dict(data)
+
+        # Handle gepa config specially to support nested structure
+        if "gepa" in pl_data and isinstance(pl_data["gepa"], dict):
+            gepa_data = pl_data["gepa"]
+            pl_data["gepa"] = GEPAConfig.from_mapping(gepa_data)
+
+        return cls.model_validate(pl_data)
+
+    @classmethod
+    def from_path(cls, path: Path) -> PromptLearningConfig:
+        """Load prompt learning config from TOML file."""
+        content = load_toml(path)
+        return cls.from_mapping(content)
+
+
+__all__ = [
+    "GEPAConfig",
+    "GEPARolloutConfig",
+    "GEPAEvaluationConfig",
+    "GEPAMutationConfig",
+    "GEPAPopulationConfig",
+    "GEPAArchiveConfig",
+    "GEPATokenConfig",
+    "MIPROConfig",
+    "MessagePatternConfig",
+    "PromptLearningConfig",
+    "PromptLearningPolicyConfig",
+    "PromptPatternConfig",
+]
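The new module accepts either the nested GEPA subsections (rollout, evaluation, mutation, population, archive, token) or the older flat field names, and the _get_* accessors resolve both layouts to the same effective values. The sketch below illustrates that behavior using only the classes shown above; the budget, concurrency, and seed numbers and the localhost task-app URL are made-up example values, not defaults from the diff.

from synth_ai.api.train.configs.prompt_learning import GEPAConfig, PromptLearningConfig

# Nested layout (preferred): sections mirror the RL config structure.
nested = GEPAConfig.from_mapping({
    "env_name": "banking77",
    "rollout": {"budget": 400, "max_concurrent": 10, "minibatch_size": 8},
    "population": {"initial_size": 20, "num_generations": 10},
})

# Flat layout (deprecated but still accepted for backwards compatibility).
flat = GEPAConfig.from_mapping({
    "env_name": "banking77",
    "rollout_budget": 400,
    "max_concurrent_rollouts": 10,
    "minibatch_size": 8,
    "initial_population_size": 20,
    "num_generations": 10,
})

# Both layouts resolve to the same effective values through the accessors.
assert nested._get_rollout_budget() == flat._get_rollout_budget() == 400
assert nested._get_num_generations() == flat._get_num_generations() == 10

# Top-level config: from_mapping unwraps an optional "prompt_learning" table
# and routes the gepa sub-table through GEPAConfig.from_mapping.
cfg = PromptLearningConfig.from_mapping({
    "prompt_learning": {
        "algorithm": "gepa",
        "task_app_url": "http://localhost:8102",  # example URL, not a default
        "gepa": {"rollout": {"budget": 400}},
    }
})
assert cfg.gepa is not None and cfg.gepa._get_rollout_budget() == 400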
synth_ai/api/train/configs/rl.py
CHANGED
@@ -115,6 +115,33 @@ class JudgeConfig(ExtraModel):
     options: JudgeOptionsConfig | None = None


+class SmokeConfig(ExtraModel):
+    """Configuration for local smoke testing (CLI only, ignored by trainer)."""
+    # Test parameters
+    task_url: str | None = None
+    env_name: str | None = None
+    policy_name: str | None = None
+    max_steps: int | None = None
+    policy: str | None = None  # mock, gpt-5-nano, openai, groq
+    model: str | None = None
+    mock_backend: str | None = None  # synthetic or openai
+    mock_port: int | None = None
+    return_trace: bool | None = None
+    use_mock: bool | None = None
+
+    # Task app auto-start configuration
+    task_app_name: str | None = None  # Task app to serve (e.g., "grpo-crafter")
+    task_app_port: int | None = None  # Port for task app (default: 8765)
+    task_app_env_file: str | None = None  # Path to .env file for task app
+    task_app_force: bool | None = None  # Use --force flag when serving
+
+    # sqld auto-start configuration
+    sqld_auto_start: bool | None = None  # Auto-start sqld server
+    sqld_db_path: str | None = None  # Database path (default: ./traces/local.db)
+    sqld_hrana_port: int | None = None  # Hrana WebSocket port (default: 8080)
+    sqld_http_port: int | None = None  # HTTP API port (default: 8081)
+
+
 class RLConfig(ExtraModel):
     algorithm: AlgorithmConfig
     services: RLServicesConfig
@@ -131,6 +158,7 @@ class RLConfig(ExtraModel):
     rubric: dict[str, Any] | None = None  # DEPRECATED: use judge.reward_blend and judge.enabled instead
     judge: JudgeConfig | None = None
     tags: dict[str, Any] | None = None
+    smoke: SmokeConfig | None = None  # CLI-only: local smoke testing config (ignored by trainer)

     def to_dict(self) -> dict[str, Any]:
         return self.model_dump(mode="python", exclude_none=True)
@@ -155,5 +183,6 @@ __all__ = [
     "RLServicesConfig",
     "RLTrainingConfig",
     "RolloutConfig",
+    "SmokeConfig",
     "WeightSyncConfig",
 ]
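SmokeConfig is read only by the CLI's local smoke-testing flow (presumably the new synth_ai/cli/commands/smoke/ files in the list above) and is ignored by the trainer, per its docstring. A minimal illustrative sketch follows; every field is optional, the crafter environment and mock policy are example choices, and only the defaults called out in the field comments (port 8765, ./traces/local.db, ports 8080/8081) come from the diff.

from synth_ai.api.train.configs.rl import SmokeConfig

# Example smoke-test settings; attach to RLConfig.smoke via the RL TOML.
smoke = SmokeConfig(
    task_app_name="grpo-crafter",      # task app to auto-serve (example from the field comment)
    task_app_port=8765,                # documented default port
    env_name="crafter",                # example environment name
    policy="mock",                     # mock, gpt-5-nano, openai, or groq
    max_steps=10,                      # example step budget
    return_trace=True,
    sqld_auto_start=True,
    sqld_db_path="./traces/local.db",  # documented default path
)
print(smoke.model_dump(exclude_none=True))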
synth_ai/api/train/task_app.py
CHANGED
@@ -38,7 +38,7 @@ def _health_response_ok(resp: requests.Response | None) -> tuple[bool, str]:
     return False, ""


-def check_task_app_health(base_url: str, api_key: str, *, timeout: float =
+def check_task_app_health(base_url: str, api_key: str, *, timeout: float = 30.0) -> TaskAppHealth:
     # Send ALL known environment keys so the server can authorize any valid one
     import os

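The visible change here is the check_task_app_health signature: the keyword-only timeout now defaults to 30.0 seconds and the function returns a TaskAppHealth. A minimal illustrative call, with a made-up local URL and API key:

from synth_ai.api.train.task_app import check_task_app_health

# Hypothetical task app URL and environment API key (example values only).
health = check_task_app_health(
    "http://localhost:8765",
    "sk-env-example",
    timeout=60.0,  # optional; omitting it uses the new 30.0 s default
)
print(health)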