synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
"""First-class SDK API for prompt learning (MIPRO and GEPA).
|
|
2
2
|
|
|
3
|
+
**Status:** Alpha
|
|
4
|
+
|
|
5
|
+
Note: MIPRO is Experimental, GEPA is Alpha.
|
|
6
|
+
|
|
3
7
|
This module provides high-level abstractions for running prompt optimization jobs
|
|
4
8
|
both via CLI (`uvx synth-ai train`) and programmatically in Python scripts.
|
|
5
9
|
|
|
@@ -8,13 +12,17 @@ Example CLI usage:
|
|
|
8
12
|
|
|
9
13
|
Example SDK usage:
|
|
10
14
|
from synth_ai.sdk.api.train.prompt_learning import PromptLearningJob
|
|
11
|
-
|
|
12
|
-
job = PromptLearningJob.
|
|
15
|
+
|
|
16
|
+
job = PromptLearningJob.from_dict(config_dict, api_key="sk_live_...")
|
|
13
17
|
job.submit()
|
|
14
|
-
result = job.poll_until_complete()
|
|
15
|
-
|
|
18
|
+
result = job.poll_until_complete(progress=True) # Built-in progress printing
|
|
19
|
+
|
|
20
|
+
if result.succeeded:
|
|
21
|
+
print(f"Best score: {result.best_score}")
|
|
22
|
+
else:
|
|
23
|
+
print(f"Failed: {result.error}")
|
|
16
24
|
|
|
17
|
-
For domain-specific
|
|
25
|
+
For domain-specific verification, you can use **Verifier Graphs**. See `PromptLearningVerifierConfig`
|
|
18
26
|
in `synth_ai.sdk.api.train.configs.prompt_learning` for configuration details.
|
|
19
27
|
"""
|
|
20
28
|
|
|
@@ -22,38 +30,186 @@ from __future__ import annotations
|
|
|
22
30
|
|
|
23
31
|
import asyncio
|
|
24
32
|
import os
|
|
25
|
-
|
|
33
|
+
import time
|
|
34
|
+
from dataclasses import dataclass, field
|
|
35
|
+
from enum import Enum
|
|
26
36
|
from pathlib import Path
|
|
27
|
-
from typing import Any, Callable, Dict, Optional
|
|
37
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
28
38
|
|
|
29
39
|
from synth_ai.core.telemetry import log_info
|
|
30
40
|
|
|
31
|
-
|
|
41
|
+
|
|
42
|
+
class JobStatus(str, Enum):
    """Status of a prompt learning job.

    Members subclass ``str``, so each one compares equal to its lowercase
    value (for example ``JobStatus.RUNNING == "running"``).
    """

    PENDING = "pending"
    QUEUED = "queued"
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    CANCELLED = "cancelled"

    @classmethod
    def from_string(cls, status: str) -> "JobStatus":
        """Convert a raw status value to a JobStatus.

        Matching ignores case and surrounding whitespace.

        Args:
            status: Raw status value, normally a string.

        Returns:
            The matching member, or ``PENDING`` when the value is unknown or
            is not a string at all (for example ``None``).
        """
        # A non-string has no .lower(); treat it like any other unknown value
        # instead of letting an AttributeError escape.
        if not isinstance(status, str):
            return cls.PENDING
        try:
            return cls(status.strip().lower())
        except ValueError:
            return cls.PENDING

    @property
    def is_terminal(self) -> bool:
        """Whether this status is terminal (job won't change further)."""
        return self in (JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.CANCELLED)

    @property
    def is_success(self) -> bool:
        """Whether this status indicates success."""
        return self == JobStatus.SUCCEEDED
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class PromptLearningResult:
    """Typed result from a prompt learning job.

    Provides clean accessors for common fields instead of raw dict access.

    Example:
        >>> result = job.poll_until_complete()
        >>> if result.succeeded:
        ...     print(f"Best score: {result.best_score}")
        ...     print(f"Best prompt: {result.best_prompt[:100]}...")
        ... else:
        ...     print(f"Failed: {result.error}")
    """

    job_id: str
    status: JobStatus
    best_score: Optional[float] = None
    best_prompt: Optional[str] = None
    error: Optional[str] = None
    raw: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_response(cls, job_id: str, data: Dict[str, Any]) -> "PromptLearningResult":
        """Create a result from an API response dict.

        Args:
            job_id: Identifier of the job the response belongs to.
            data: Response mapping; it is stored unchanged on ``raw``.

        Returns:
            A result whose ``status`` falls back to ``PENDING`` when the
            response has no usable status string.
        """
        # A missing, null, or non-string status is treated as "pending".
        status_raw = data.get("status")
        status = JobStatus.from_string(status_raw if isinstance(status_raw, str) else "pending")

        # Extract best score from various field names (backward compat).
        # Compare against None rather than truthiness so that a legitimate
        # score of 0.0 is not skipped in favour of a later (or no) field.
        score_keys = ("best_score", "best_reward", "best_train_score", "best_train_reward")
        best_score = next(
            (data[key] for key in score_keys if data.get(key) is not None),
            None,
        )

        return cls(
            job_id=job_id,
            status=status,
            best_score=best_score,
            best_prompt=data.get("best_prompt"),
            error=data.get("error"),
            raw=data,
        )

    @property
    def succeeded(self) -> bool:
        """Whether the job succeeded."""
        return self.status.is_success

    @property
    def failed(self) -> bool:
        """Whether the job failed."""
        return self.status == JobStatus.FAILED

    @property
    def is_terminal(self) -> bool:
        """Whether the job has reached a terminal state."""
        return self.status.is_terminal
|
|
130
|
+
|
|
131
|
+
from .builders import (
|
|
132
|
+
PromptLearningBuildResult,
|
|
133
|
+
build_prompt_learning_payload,
|
|
134
|
+
build_prompt_learning_payload_from_mapping,
|
|
135
|
+
)
|
|
32
136
|
from .pollers import JobPoller, PollOutcome
|
|
33
|
-
from .
|
|
34
|
-
from .utils import ensure_api_base, http_post
|
|
137
|
+
from .local_api import check_local_api_health
|
|
138
|
+
from .utils import ensure_api_base, http_get, http_post
|
|
35
139
|
|
|
36
140
|
|
|
37
141
|
@dataclass
|
|
38
142
|
class PromptLearningJobConfig:
|
|
39
|
-
"""Configuration for a prompt learning job.
|
|
40
|
-
|
|
41
|
-
|
|
143
|
+
"""Configuration for a prompt learning job.
|
|
144
|
+
|
|
145
|
+
This dataclass holds all the configuration needed to submit and run
|
|
146
|
+
a prompt learning job (MIPRO or GEPA optimization).
|
|
147
|
+
|
|
148
|
+
Supports two modes:
|
|
149
|
+
1. **File-based**: Provide `config_path` pointing to a TOML file
|
|
150
|
+
2. **Programmatic**: Provide `config_dict` with the configuration directly
|
|
151
|
+
|
|
152
|
+
Both modes go through the same `PromptLearningConfig` Pydantic validation.
|
|
153
|
+
|
|
154
|
+
Attributes:
|
|
155
|
+
config_path: Path to the TOML configuration file. Mutually exclusive with config_dict.
|
|
156
|
+
config_dict: Dictionary with prompt learning configuration. Mutually exclusive with config_path.
|
|
157
|
+
Should have the same structure as the TOML file (with 'prompt_learning' section).
|
|
158
|
+
backend_url: Base URL of the Synth API backend (e.g., "https://api.usesynth.ai").
|
|
159
|
+
api_key: Synth API key for authentication.
|
|
160
|
+
task_app_api_key: API key for authenticating with the Local API.
|
|
161
|
+
allow_experimental: If True, allows use of experimental models.
|
|
162
|
+
overrides: Dictionary of config overrides.
|
|
163
|
+
|
|
164
|
+
Example (file-based):
|
|
165
|
+
>>> config = PromptLearningJobConfig(
|
|
166
|
+
... config_path=Path("my_config.toml"),
|
|
167
|
+
... backend_url="https://api.usesynth.ai",
|
|
168
|
+
... api_key="sk_live_...",
|
|
169
|
+
... )
|
|
170
|
+
|
|
171
|
+
Example (programmatic):
|
|
172
|
+
>>> config = PromptLearningJobConfig(
|
|
173
|
+
... config_dict={
|
|
174
|
+
... "prompt_learning": {
|
|
175
|
+
... "algorithm": "gepa",
|
|
176
|
+
... "task_app_url": "https://tunnel.example.com",
|
|
177
|
+
... "policy": {"model": "gpt-4o-mini", "provider": "openai"},
|
|
178
|
+
... "gepa": {...},
|
|
179
|
+
... }
|
|
180
|
+
... },
|
|
181
|
+
... backend_url="https://api.usesynth.ai",
|
|
182
|
+
... api_key="sk_live_...",
|
|
183
|
+
... )
|
|
184
|
+
"""
|
|
185
|
+
|
|
42
186
|
backend_url: str
|
|
43
187
|
api_key: str
|
|
188
|
+
config_path: Optional[Path] = None
|
|
189
|
+
config_dict: Optional[Dict[str, Any]] = None
|
|
44
190
|
task_app_api_key: Optional[str] = None
|
|
45
191
|
allow_experimental: Optional[bool] = None
|
|
46
192
|
overrides: Optional[Dict[str, Any]] = None
|
|
47
|
-
|
|
193
|
+
|
|
48
194
|
def __post_init__(self) -> None:
|
|
49
195
|
"""Validate configuration."""
|
|
50
|
-
|
|
196
|
+
# Must provide exactly one of config_path or config_dict
|
|
197
|
+
has_path = self.config_path is not None
|
|
198
|
+
has_dict = self.config_dict is not None
|
|
199
|
+
|
|
200
|
+
if has_path and has_dict:
|
|
201
|
+
raise ValueError("Provide either config_path OR config_dict, not both")
|
|
202
|
+
if not has_path and not has_dict:
|
|
203
|
+
raise ValueError("Either config_path or config_dict is required")
|
|
204
|
+
|
|
205
|
+
if has_path and not self.config_path.exists():
|
|
51
206
|
raise FileNotFoundError(f"Config file not found: {self.config_path}")
|
|
207
|
+
|
|
52
208
|
if not self.backend_url:
|
|
53
209
|
raise ValueError("backend_url is required")
|
|
54
210
|
if not self.api_key:
|
|
55
211
|
raise ValueError("api_key is required")
|
|
56
|
-
|
|
212
|
+
|
|
57
213
|
# Get task_app_api_key from environment if not provided
|
|
58
214
|
if not self.task_app_api_key:
|
|
59
215
|
self.task_app_api_key = os.environ.get("ENVIRONMENT_API_KEY")
|
|
@@ -184,9 +340,100 @@ class PromptLearningJob:
|
|
|
184
340
|
allow_experimental=allow_experimental,
|
|
185
341
|
overrides=overrides or {},
|
|
186
342
|
)
|
|
187
|
-
|
|
343
|
+
|
|
188
344
|
return cls(config)
|
|
189
|
-
|
|
345
|
+
|
|
346
|
+
@classmethod
def from_dict(
    cls,
    config_dict: Dict[str, Any],
    backend_url: Optional[str] = None,
    api_key: Optional[str] = None,
    task_app_api_key: Optional[str] = None,
    allow_experimental: Optional[bool] = None,
    overrides: Optional[Dict[str, Any]] = None,
    skip_health_check: bool = False,
) -> PromptLearningJob:
    """Build a job from an in-memory configuration mapping.

    Use this instead of a TOML file when working programmatically, e.g. in
    notebooks, scripts, and applications. The mapping mirrors the TOML
    layout:

    ```python
    {
        "prompt_learning": {
            "algorithm": "gepa",
            "task_app_url": "https://...",
            "policy": {"model": "gpt-4o-mini", "provider": "openai"},
            "gepa": {...},
        }
    }
    ```

    Args:
        config_dict: Configuration dictionary with a 'prompt_learning' section.
        backend_url: Backend API URL. When empty, the BACKEND_BASE_URL env
            var is tried, then the base returned by get_backend_from_env()
            (with "/api" appended if it is not already the suffix).
        api_key: API key. When empty, the SYNTH_API_KEY env var is used.
        task_app_api_key: Task app API key, forwarded to the job config.
        allow_experimental: Allow experimental models.
        overrides: Config overrides; an empty dict is used when omitted.
        skip_health_check: If True, skip task app health check before submission.

    Returns:
        PromptLearningJob instance.

    Raises:
        ValueError: If no API key can be resolved, or if
            PromptLearningJobConfig rejects the supplied values.

    Example:
        >>> job = PromptLearningJob.from_dict(
        ...     config_dict={
        ...         "prompt_learning": {
        ...             "algorithm": "gepa",
        ...             "task_app_url": "https://tunnel.example.com",
        ...             "policy": {"model": "gpt-4o-mini", "provider": "openai"},
        ...             "gepa": {
        ...                 "rollout": {"budget": 50, "max_concurrent": 5},
        ...                 "evaluation": {"train_seeds": [1, 2, 3], "val_seeds": [4, 5]},
        ...                 "population": {"num_generations": 2, "children_per_generation": 2},
        ...             },
        ...         }
        ...     },
        ...     api_key="sk_live_...",
        ... )
        >>> job_id = job.submit()
    """
    import os

    from synth_ai.core.env import get_backend_from_env

    # Backend URL: explicit argument, then environment, then configured default.
    resolved_backend = backend_url or os.environ.get("BACKEND_BASE_URL", "").strip()
    if not resolved_backend:
        base, _ = get_backend_from_env()
        resolved_backend = base if base.endswith("/api") else f"{base}/api"

    # API key: explicit argument, then environment; there is no further fallback.
    resolved_key = api_key or os.environ.get("SYNTH_API_KEY")
    if not resolved_key:
        raise ValueError(
            "api_key is required (provide explicitly or set SYNTH_API_KEY env var)"
        )

    job_config = PromptLearningJobConfig(
        config_dict=config_dict,
        backend_url=resolved_backend,
        api_key=resolved_key,
        task_app_api_key=task_app_api_key,
        allow_experimental=allow_experimental,
        overrides=overrides or {},
    )
    return cls(job_config, skip_health_check=skip_health_check)
|
|
436
|
+
|
|
190
437
|
@classmethod
|
|
191
438
|
def from_job_id(
|
|
192
439
|
cls,
|
|
@@ -223,33 +470,59 @@ class PromptLearningJob:
|
|
|
223
470
|
"api_key is required (provide explicitly or set SYNTH_API_KEY env var)"
|
|
224
471
|
)
|
|
225
472
|
|
|
226
|
-
# Create minimal config (we don't need the config
|
|
473
|
+
# Create minimal config (we don't need the config for resuming - use empty dict)
|
|
474
|
+
# The config_dict is never used when resuming since we have the job_id
|
|
227
475
|
config = PromptLearningJobConfig(
|
|
228
|
-
|
|
476
|
+
config_dict={"prompt_learning": {"_resumed": True}}, # Placeholder for resume mode
|
|
229
477
|
backend_url=backend_url,
|
|
230
478
|
api_key=api_key,
|
|
231
479
|
)
|
|
232
|
-
|
|
480
|
+
|
|
233
481
|
return cls(config, job_id=job_id)
|
|
234
482
|
|
|
235
483
|
def _build_payload(self) -> PromptLearningBuildResult:
    """Build the job payload from config, caching the result.

    Supports both file-based (config_path) and programmatic (config_dict) modes.
    The first call stores the build result on ``self._build_result``; later
    calls return that stored value without rebuilding.

    Returns:
        The build result produced by the dict-based or path-based builder.

    Raises:
        RuntimeError: If the config file does not exist, or if neither
            config_path nor config_dict is set.
    """
    if self._build_result is not None:
        return self._build_result

    # Work on a copy: the config's overrides dict may be the very object the
    # caller passed in, and must not gain "backend"/"task_app_api_key" keys.
    overrides = dict(self.config.overrides or {})
    overrides["backend"] = self.config.backend_url
    # Pass task_app_api_key to builder via overrides
    if self.config.task_app_api_key:
        overrides["task_app_api_key"] = self.config.task_app_api_key

    # Route to appropriate builder based on config mode
    if self.config.config_dict is not None:
        # Programmatic mode: use dict-based builder
        self._build_result = build_prompt_learning_payload_from_mapping(
            raw_config=self.config.config_dict,
            task_url=None,
            overrides=overrides,
            allow_experimental=self.config.allow_experimental,
            source_label="PromptLearningJob.from_dict",
        )
    elif self.config.config_path is not None:
        # File-based mode: use path-based builder
        if not self.config.config_path.exists():
            raise RuntimeError(
                f"Config file not found: {self.config.config_path}. "
                "Use from_dict() for programmatic config or from_job_id() to resume."
            )

        self._build_result = build_prompt_learning_payload(
            config_path=self.config.config_path,
            task_url=None,
            overrides=overrides,
            allow_experimental=self.config.allow_experimental,
        )
    else:
        raise RuntimeError(
            "Cannot build payload: either config_path or config_dict is required. "
            "Use from_config() for file-based config, from_dict() for programmatic config, "
            "or from_job_id() to resume an existing job."
        )
    return self._build_result
|
|
254
527
|
|
|
255
528
|
def submit(self) -> str:
|
|
@@ -262,7 +535,11 @@ class PromptLearningJob:
|
|
|
262
535
|
RuntimeError: If job submission fails
|
|
263
536
|
ValueError: If task app health check fails
|
|
264
537
|
"""
|
|
265
|
-
|
|
538
|
+
# Log context based on config mode
|
|
539
|
+
if self.config.config_path is not None:
|
|
540
|
+
ctx: Dict[str, Any] = {"config_path": str(self.config.config_path)}
|
|
541
|
+
else:
|
|
542
|
+
ctx = {"config_mode": "programmatic"}
|
|
266
543
|
log_info("PromptLearningJob.submit invoked", ctx=ctx)
|
|
267
544
|
if self._job_id:
|
|
268
545
|
raise RuntimeError(f"Job already submitted: {self._job_id}")
|
|
@@ -271,7 +548,7 @@ class PromptLearningJob:
|
|
|
271
548
|
|
|
272
549
|
# Health check (skip if _skip_health_check is set - useful for tunnels with DNS delay)
|
|
273
550
|
if not self._skip_health_check:
|
|
274
|
-
health =
|
|
551
|
+
health = check_local_api_health(build.task_url, self.config.task_app_api_key or "")
|
|
275
552
|
if not health.ok:
|
|
276
553
|
raise ValueError(f"Task app health check failed: {health.detail}")
|
|
277
554
|
|
|
@@ -351,40 +628,92 @@ class PromptLearningJob:
|
|
|
351
628
|
*,
|
|
352
629
|
timeout: float = 3600.0,
|
|
353
630
|
interval: float = 5.0,
|
|
631
|
+
progress: bool = False,
|
|
354
632
|
on_status: Optional[Callable[[Dict[str, Any]], None]] = None,
|
|
355
|
-
) ->
|
|
633
|
+
) -> PromptLearningResult:
|
|
356
634
|
"""Poll job until it reaches a terminal state.
|
|
357
|
-
|
|
635
|
+
|
|
358
636
|
Args:
|
|
359
637
|
timeout: Maximum seconds to wait
|
|
360
638
|
interval: Seconds between poll attempts
|
|
361
|
-
|
|
362
|
-
|
|
639
|
+
progress: If True, print status updates during polling (useful for notebooks)
|
|
640
|
+
on_status: Optional callback called on each status update (for custom progress handling)
|
|
641
|
+
|
|
363
642
|
Returns:
|
|
364
|
-
|
|
365
|
-
|
|
643
|
+
PromptLearningResult with typed status, best_score, etc.
|
|
644
|
+
|
|
366
645
|
Raises:
|
|
367
646
|
RuntimeError: If job hasn't been submitted yet
|
|
368
647
|
TimeoutError: If timeout is exceeded
|
|
648
|
+
|
|
649
|
+
Example:
|
|
650
|
+
>>> result = job.poll_until_complete(progress=True)
|
|
651
|
+
[00:15] running | score: 0.72
|
|
652
|
+
[00:30] running | score: 0.78
|
|
653
|
+
[00:45] succeeded | score: 0.85
|
|
654
|
+
>>> result.succeeded
|
|
655
|
+
True
|
|
656
|
+
>>> result.best_score
|
|
657
|
+
0.85
|
|
369
658
|
"""
|
|
370
659
|
if not self._job_id:
|
|
371
660
|
raise RuntimeError("Job not yet submitted. Call submit() first.")
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
661
|
+
|
|
662
|
+
job_id = self._job_id
|
|
663
|
+
base_url = ensure_api_base(self.config.backend_url)
|
|
664
|
+
headers = {
|
|
665
|
+
"Authorization": f"Bearer {self.config.api_key}",
|
|
666
|
+
"Content-Type": "application/json",
|
|
667
|
+
}
|
|
668
|
+
|
|
669
|
+
start_time = time.time()
|
|
670
|
+
elapsed = 0.0
|
|
671
|
+
last_data: Dict[str, Any] = {}
|
|
672
|
+
|
|
673
|
+
while elapsed <= timeout:
|
|
674
|
+
try:
|
|
675
|
+
# Fetch job status
|
|
676
|
+
url = f"{base_url}/prompt-learning/online/jobs/{job_id}"
|
|
677
|
+
resp = http_get(url, headers=headers)
|
|
678
|
+
data = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
|
|
679
|
+
last_data = dict(data) if isinstance(data, dict) else {}
|
|
680
|
+
|
|
681
|
+
status = JobStatus.from_string(last_data.get("status", "pending"))
|
|
682
|
+
best_score = (
|
|
683
|
+
last_data.get("best_score")
|
|
684
|
+
or last_data.get("best_reward")
|
|
685
|
+
or last_data.get("best_train_score")
|
|
686
|
+
or last_data.get("best_train_reward")
|
|
687
|
+
)
|
|
688
|
+
|
|
689
|
+
# Progress output
|
|
690
|
+
if progress:
|
|
691
|
+
mins, secs = divmod(int(elapsed), 60)
|
|
692
|
+
score_str = f"score: {best_score:.2f}" if best_score is not None else "score: --"
|
|
693
|
+
print(f"[{mins:02d}:{secs:02d}] {status.value} | {score_str}")
|
|
694
|
+
|
|
695
|
+
# Callback for custom handling
|
|
696
|
+
if on_status:
|
|
697
|
+
on_status(last_data)
|
|
698
|
+
|
|
699
|
+
# Check terminal state
|
|
700
|
+
if status.is_terminal:
|
|
701
|
+
return PromptLearningResult.from_response(job_id, last_data)
|
|
702
|
+
|
|
703
|
+
except Exception as exc:
|
|
704
|
+
if progress:
|
|
705
|
+
print(f"[poll] error: {exc}")
|
|
706
|
+
log_info("poll request failed", ctx={"error": str(exc), "job_id": job_id})
|
|
707
|
+
|
|
708
|
+
time.sleep(interval)
|
|
709
|
+
elapsed = time.time() - start_time
|
|
710
|
+
|
|
711
|
+
# Timeout reached
|
|
712
|
+
if progress:
|
|
713
|
+
print(f"[poll] timeout after {timeout:.0f}s")
|
|
714
|
+
|
|
715
|
+
# Return with whatever data we have, status will indicate not complete
|
|
716
|
+
return PromptLearningResult.from_response(job_id, last_data)
|
|
388
717
|
|
|
389
718
|
def get_results(self) -> Dict[str, Any]:
|
|
390
719
|
"""Get job results (prompts, scores, etc.).
|
|
@@ -463,8 +792,9 @@ class PromptLearningJob:
|
|
|
463
792
|
|
|
464
793
|
|
|
465
794
|
__all__ = [
|
|
795
|
+
"JobStatus",
|
|
466
796
|
"PromptLearningJob",
|
|
467
797
|
"PromptLearningJobConfig",
|
|
468
798
|
"PromptLearningJobPoller",
|
|
799
|
+
"PromptLearningResult",
|
|
469
800
|
]
|
|
470
|
-
|
synth_ai/sdk/api/train/rl.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""First-class SDK API for reinforcement learning (RL/GSPO).
|
|
2
2
|
|
|
3
|
+
**Status:** Experimental
|
|
4
|
+
|
|
3
5
|
This module provides high-level abstractions for running RL training jobs
|
|
4
6
|
both via CLI (`uvx synth-ai train --type rl`) and programmatically in Python scripts.
|
|
5
7
|
|
|
@@ -32,14 +34,49 @@ from synth_ai.core.telemetry import log_info
|
|
|
32
34
|
|
|
33
35
|
from .builders import RLBuildResult, build_rl_payload
|
|
34
36
|
from .pollers import RLJobPoller
|
|
35
|
-
from .
|
|
37
|
+
from .local_api import check_local_api_health
|
|
36
38
|
from .utils import ensure_api_base, http_post
|
|
37
39
|
|
|
38
40
|
|
|
39
41
|
@dataclass
|
|
40
42
|
class RLJobConfig:
|
|
41
|
-
"""Configuration for an RL training job.
|
|
42
|
-
|
|
43
|
+
"""Configuration for an RL training job.
|
|
44
|
+
|
|
45
|
+
This dataclass holds all the configuration needed to submit and run
|
|
46
|
+
a reinforcement learning training job (GSPO, GRPO, PPO, etc.).
|
|
47
|
+
|
|
48
|
+
Attributes:
|
|
49
|
+
config_path: Path to the TOML configuration file that defines the
|
|
50
|
+
RL training task, including model settings, training hyperparameters,
|
|
51
|
+
reward configuration, and Local API URL.
|
|
52
|
+
backend_url: Base URL of the Synth API backend (e.g.,
|
|
53
|
+
"https://api.usesynth.ai"). Can also be set via BACKEND_BASE_URL
|
|
54
|
+
environment variable.
|
|
55
|
+
api_key: Synth API key for authentication. Can also be set via
|
|
56
|
+
SYNTH_API_KEY environment variable.
|
|
57
|
+
task_app_url: URL of the Local API that serves rollout environments.
|
|
58
|
+
Can be set via TASK_APP_URL env var if not provided.
|
|
59
|
+
(Alias: also known as "task app URL" in older documentation)
|
|
60
|
+
task_app_api_key: API key for authenticating with the Local API.
|
|
61
|
+
Defaults to ENVIRONMENT_API_KEY env var if not provided.
|
|
62
|
+
(Alias: also known as "task app API key" in older documentation)
|
|
63
|
+
allow_experimental: If True, allows use of experimental models and
|
|
64
|
+
features. Defaults to None (uses config file setting).
|
|
65
|
+
overrides: Dictionary of config overrides that take precedence over
|
|
66
|
+
values in the TOML file. Useful for programmatic customization.
|
|
67
|
+
idempotency_key: Optional key for idempotent job submission. If provided,
|
|
68
|
+
submitting the same key twice will return the existing job instead
|
|
69
|
+
of creating a new one.
|
|
70
|
+
|
|
71
|
+
Example:
|
|
72
|
+
>>> config = RLJobConfig(
|
|
73
|
+
... config_path=Path("rl_config.toml"),
|
|
74
|
+
... backend_url="https://api.usesynth.ai",
|
|
75
|
+
... api_key="sk_live_...",
|
|
76
|
+
... task_app_url="https://my-task-app.example.com",
|
|
77
|
+
... )
|
|
78
|
+
"""
|
|
79
|
+
|
|
43
80
|
config_path: Path
|
|
44
81
|
backend_url: str
|
|
45
82
|
api_key: str
|
|
@@ -282,7 +319,7 @@ class RLJob:
|
|
|
282
319
|
# Health check (skip if _skip_health_check is set - useful for tunnels with DNS delay)
|
|
283
320
|
if not self._skip_health_check:
|
|
284
321
|
task_app_key = self.config.task_app_api_key or ""
|
|
285
|
-
health =
|
|
322
|
+
health = check_local_api_health(build.task_url, task_app_key)
|
|
286
323
|
if not health.ok:
|
|
287
324
|
raise ValueError(f"Task app health check failed: {health.detail}")
|
|
288
325
|
|
|
@@ -439,4 +476,3 @@ __all__ = [
|
|
|
439
476
|
"RLJob",
|
|
440
477
|
"RLJobConfig",
|
|
441
478
|
]
|
|
442
|
-
|
synth_ai/sdk/api/train/sft.py
CHANGED
|
@@ -21,6 +21,11 @@ class TaskAppHealth:
|
|
|
21
21
|
detail: str | None = None
|
|
22
22
|
|
|
23
23
|
|
|
24
|
+
@dataclass(slots=True)
class LocalAPIHealth(TaskAppHealth):
    """TaskAppHealth under the LocalAPI naming.

    This is a subclass that adds no fields of its own, not a true alias: a
    plain TaskAppHealth instance is not an instance of LocalAPIHealth.
    """
|
|
27
|
+
|
|
28
|
+
|
|
24
29
|
def _resolve_hostname_with_explicit_resolvers(hostname: str) -> str:
|
|
25
30
|
"""
|
|
26
31
|
Resolve hostname using explicit resolvers (1.1.1.1, 8.8.8.8) first,
|
|
@@ -245,6 +250,19 @@ def check_task_app_health(base_url: str, api_key: str, *, timeout: float = 10.0,
|
|
|
245
250
|
)
|
|
246
251
|
|
|
247
252
|
|
|
253
|
+
def check_local_api_health(
    base_url: str, api_key: str, *, timeout: float = 10.0, max_retries: int = 5
) -> LocalAPIHealth:
    """Run check_task_app_health and return its result as a LocalAPIHealth.

    All arguments are forwarded unchanged; the ok, health_status,
    task_info_status and detail values are copied onto the returned object.
    """
    result = check_task_app_health(
        base_url,
        api_key,
        timeout=timeout,
        max_retries=max_retries,
    )
    copied_fields = ("ok", "health_status", "task_info_status", "detail")
    return LocalAPIHealth(**{name: getattr(result, name) for name in copied_fields})
|
|
264
|
+
|
|
265
|
+
|
|
248
266
|
@dataclass(slots=True)
|
|
249
267
|
class ModalSecret:
|
|
250
268
|
name: str
|
|
@@ -323,9 +341,11 @@ __all__ = [
|
|
|
323
341
|
"ModalApp",
|
|
324
342
|
"ModalSecret",
|
|
325
343
|
"check_task_app_health",
|
|
344
|
+
"check_local_api_health",
|
|
326
345
|
"format_modal_apps",
|
|
327
346
|
"format_modal_secrets",
|
|
328
347
|
"get_modal_secret_value",
|
|
329
348
|
"list_modal_apps",
|
|
330
349
|
"list_modal_secrets",
|
|
350
|
+
"LocalAPIHealth",
|
|
331
351
|
]
|