synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -1,73 +1,290 @@
|
|
|
1
|
-
"""Research Agent
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
1
|
+
"""Research Agent SDK models and job helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
import os
|
|
9
|
+
import tomllib
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class OptimizationTool(str, Enum):
    """Optimization tools that can be listed in ``ResearchConfig.tools``.

    Members subclass ``str``, so each one compares equal to its raw value.
    """

    MIPRO = "mipro"
    GEPA = "gepa"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ModelProvider(str, Enum):
    """Model providers recognised by the research-agent configuration.

    Members subclass ``str``, so each one compares equal to its raw value.
    """

    OPENAI = "openai"
    GROQ = "groq"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class PermittedModel:
    """A model identifier together with its provider."""

    model: str
    provider: ModelProvider

    def to_dict(self) -> dict[str, Any]:
        """Return a plain dict with the provider flattened to its enum value."""
        payload: dict[str, Any] = {}
        payload["model"] = self.model
        payload["provider"] = self.provider.value
        return payload
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class PermittedModelsConfig:
    """A list of permitted models plus an optional default temperature."""

    models: list[PermittedModel] = field(default_factory=list)
    default_temperature: float | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict.

        ``default_temperature`` is included only when it is not ``None``
        (so an explicit ``0.0`` is kept).
        """
        # Annotated as dict[str, Any] like the sibling to_dict() methods;
        # without it the dict is inferred as str -> list and the float
        # assignment below fails static type checking.
        data: dict[str, Any] = {"models": [model.to_dict() for model in self.models]}
        if self.default_temperature is not None:
            data["default_temperature"] = self.default_temperature
        return data
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class DatasetSource:
    """Description of one dataset input for a research job."""

    source_type: str
    hf_repo_id: str | None = None
    hf_split: str | None = None
    description: str | None = None
    file_ids: list[str] | None = None
    inline_data: dict[str, str] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, leaving out fields that are unset."""
        data: dict[str, Any] = {"source_type": self.source_type}

        # String fields are emitted only when non-empty.
        for key in ("hf_repo_id", "hf_split", "description"):
            text = getattr(self, key)
            if text:
                data[key] = text

        # Collection fields are emitted whenever they are set, even if empty.
        for key in ("file_ids", "inline_data"):
            collection = getattr(self, key)
            if collection is not None:
                data[key] = collection

        return data
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class MIPROConfig:
    """Settings for the MIPRO optimization tool."""

    meta_model: str = "llama-3.3-70b-versatile"
    meta_provider: ModelProvider = ModelProvider.GROQ
    num_trials: int = 10
    proposer_effort: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; ``proposer_effort`` is omitted when ``None``."""
        optional: dict[str, Any] = {}
        if self.proposer_effort is not None:
            optional["proposer_effort"] = self.proposer_effort
        return {
            "meta_model": self.meta_model,
            "meta_provider": self.meta_provider.value,
            "num_trials": self.num_trials,
            **optional,
        }
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclass
class GEPAConfig:
    """Settings for the GEPA optimization tool."""

    mutation_model: str = "openai/gpt-oss-120b"
    population_size: int = 20
    proposer_type: str = "dspy"
    spec_path: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; ``spec_path`` is omitted when ``None``."""
        optional: dict[str, Any] = {}
        if self.spec_path is not None:
            optional["spec_path"] = self.spec_path
        return {
            "mutation_model": self.mutation_model,
            "population_size": self.population_size,
            "proposer_type": self.proposer_type,
            **optional,
        }
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@dataclass
class ResearchConfig:
    """Task definition for a research job: description, tools, datasets and metric."""

    task_description: str
    tools: list[OptimizationTool] = field(default_factory=list)
    datasets: list[DatasetSource] = field(default_factory=list)
    primary_metric: str = "accuracy"
    num_iterations: int = 10
    mipro_config: MIPROConfig | None = None
    gepa_config: GEPAConfig | None = None
    permitted_models: PermittedModelsConfig | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict.

        ``datasets`` is included only when non-empty; each optional nested
        config is included only when it is not ``None``.
        """
        data: dict[str, Any] = dict(
            task_description=self.task_description,
            tools=[tool.value for tool in self.tools],
            primary_metric=self.primary_metric,
            num_iterations=self.num_iterations,
        )
        if self.datasets:
            data["datasets"] = [ds.to_dict() for ds in self.datasets]

        # Nested configs share the same "serialize if present" rule.
        for key in ("mipro_config", "gepa_config", "permitted_models"):
            nested = getattr(self, key)
            if nested is not None:
                data[key] = nested.to_dict()
        return data
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@dataclass
|
|
134
|
+
class ResearchAgentJobConfig:
|
|
135
|
+
research: ResearchConfig
|
|
136
|
+
repo_url: str = ""
|
|
137
|
+
repo_branch: str | None = None
|
|
138
|
+
inline_files: dict[str, str] | None = None
|
|
139
|
+
backend_url: str = ""
|
|
140
|
+
api_key: str = ""
|
|
141
|
+
allow_missing_api_key: bool = False
|
|
142
|
+
backend: str | None = None
|
|
143
|
+
model: str | None = None
|
|
144
|
+
max_agent_spend_usd: float | None = None
|
|
145
|
+
max_synth_spend_usd: float | None = None
|
|
146
|
+
reasoning_effort: str | None = None
|
|
147
|
+
|
|
148
|
+
def __post_init__(self) -> None:
|
|
149
|
+
if not self.repo_url and not self.inline_files:
|
|
150
|
+
raise ValueError("Either repo_url or inline_files must be provided")
|
|
151
|
+
if not self.api_key:
|
|
152
|
+
self.api_key = os.getenv("SYNTH_API_KEY", "").strip()
|
|
153
|
+
if not self.api_key and not self.allow_missing_api_key:
|
|
154
|
+
raise ValueError("api_key is required")
|
|
155
|
+
if not self.backend_url:
|
|
156
|
+
self.backend_url = "https://api.usesynth.ai"
|
|
157
|
+
|
|
158
|
+
@classmethod
|
|
159
|
+
def from_toml(cls, path: str | Path) -> "ResearchAgentJobConfig":
|
|
160
|
+
path = Path(path)
|
|
161
|
+
if not path.exists():
|
|
162
|
+
raise FileNotFoundError(path)
|
|
163
|
+
data = tomllib.loads(path.read_text(encoding="utf-8"))
|
|
164
|
+
if "research_agent" not in data:
|
|
165
|
+
raise ValueError("Config must have [research_agent] section")
|
|
166
|
+
section = data["research_agent"]
|
|
167
|
+
research_section = section.get("research")
|
|
168
|
+
if research_section is None:
|
|
169
|
+
raise ValueError("research_agent.research config is required")
|
|
170
|
+
|
|
171
|
+
tools = [OptimizationTool(tool) for tool in research_section.get("tools", [])]
|
|
172
|
+
datasets = [
|
|
24
173
|
DatasetSource(
|
|
25
|
-
source_type="
|
|
26
|
-
hf_repo_id="
|
|
174
|
+
source_type=ds.get("source_type", ""),
|
|
175
|
+
hf_repo_id=ds.get("hf_repo_id"),
|
|
176
|
+
hf_split=ds.get("hf_split"),
|
|
177
|
+
description=ds.get("description"),
|
|
178
|
+
file_ids=ds.get("file_ids"),
|
|
179
|
+
inline_data=ds.get("inline_data"),
|
|
180
|
+
)
|
|
181
|
+
for ds in research_section.get("datasets", [])
|
|
182
|
+
]
|
|
183
|
+
mipro_cfg = None
|
|
184
|
+
if research_section.get("mipro_config"):
|
|
185
|
+
cfg = research_section["mipro_config"]
|
|
186
|
+
mipro_cfg = MIPROConfig(
|
|
187
|
+
meta_model=cfg.get("meta_model", MIPROConfig.meta_model),
|
|
188
|
+
meta_provider=ModelProvider(cfg.get("meta_provider", ModelProvider.GROQ.value)),
|
|
189
|
+
num_trials=cfg.get("num_trials", MIPROConfig.num_trials),
|
|
190
|
+
proposer_effort=cfg.get("proposer_effort"),
|
|
27
191
|
)
|
|
28
|
-
],
|
|
29
|
-
)
|
|
30
192
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
193
|
+
research = ResearchConfig(
|
|
194
|
+
task_description=research_section.get("task_description", ""),
|
|
195
|
+
tools=tools,
|
|
196
|
+
datasets=datasets,
|
|
197
|
+
primary_metric=research_section.get("primary_metric", "accuracy"),
|
|
198
|
+
num_iterations=research_section.get("num_iterations", 10),
|
|
199
|
+
mipro_config=mipro_cfg,
|
|
200
|
+
)
|
|
37
201
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
""
|
|
202
|
+
return cls(
|
|
203
|
+
research=research,
|
|
204
|
+
repo_url=section.get("repo_url", "") or "",
|
|
205
|
+
repo_branch=section.get("repo_branch"),
|
|
206
|
+
backend=section.get("backend"),
|
|
207
|
+
model=section.get("model"),
|
|
208
|
+
max_agent_spend_usd=section.get("max_agent_spend_usd"),
|
|
209
|
+
max_synth_spend_usd=section.get("max_synth_spend_usd"),
|
|
210
|
+
reasoning_effort=section.get("reasoning_effort"),
|
|
211
|
+
backend_url=section.get("backend_url", ""),
|
|
212
|
+
api_key=section.get("api_key", ""),
|
|
213
|
+
allow_missing_api_key=True,
|
|
214
|
+
)
|
|
42
215
|
|
|
43
|
-
from __future__ import annotations
|
|
44
216
|
|
|
45
|
-
|
|
217
|
+
class ResearchAgentJob:
    """Handle for a research-agent job built from a ``ResearchAgentJobConfig``.

    The job id stays ``None`` until ``submit()`` is called or the instance is
    created through ``from_id()``.
    """

    def __init__(self, *, config: ResearchAgentJobConfig) -> None:
        """Store ``config``; no job id is assigned yet."""
        self.config = config
        self._job_id: str | None = None

    @property
    def job_id(self) -> str | None:
        """The job id, or ``None`` when the job has not been submitted."""
        return self._job_id

    @classmethod
    def from_research_config(
        cls,
        *,
        research: ResearchConfig,
        repo_url: str,
        backend_url: str,
        api_key: str,
        model: str | None = None,
        max_agent_spend_usd: float | None = None,
    ) -> "ResearchAgentJob":
        """Create a job by wrapping ``research`` in a new job config."""
        return cls(
            config=ResearchAgentJobConfig(
                research=research,
                repo_url=repo_url,
                backend_url=backend_url,
                api_key=api_key,
                model=model,
                max_agent_spend_usd=max_agent_spend_usd,
            )
        )

    @classmethod
    def from_id(
        cls,
        *,
        job_id: str,
        backend_url: str,
        api_key: str,
    ) -> "ResearchAgentJob":
        """Create a handle for an already-known ``job_id``.

        The config is filled with placeholder research/repo values because
        only the backend URL and API key are supplied here.
        """
        placeholder = ResearchAgentJobConfig(
            research=ResearchConfig(task_description="Existing research job"),
            repo_url="existing",
            backend_url=backend_url,
            api_key=api_key,
        )
        instance = cls(config=placeholder)
        instance._job_id = job_id
        return instance

    def _submitted_snapshot(self) -> dict[str, Any]:
        """Return the status dict, raising if no job id has been assigned."""
        if self._job_id is None:
            raise RuntimeError("Job not submitted yet")
        return {"job_id": self._job_id, "status": "submitted"}

    def submit(self) -> str:
        """Mark the job as submitted and return its id.

        Raises:
            RuntimeError: if the job already has an id.
            NotImplementedError: if the research config requests GEPA.
        """
        if self._job_id is not None:
            raise RuntimeError("Job already submitted")
        requested_tools = self.config.research.tools
        if OptimizationTool.GEPA in requested_tools:
            raise NotImplementedError("GEPA optimization is not yet fully supported")
        self._job_id = "ra_pending"
        return self._job_id

    def poll_until_complete(self) -> dict[str, Any]:
        """Return the job's status dict; raises ``RuntimeError`` before submission."""
        return self._submitted_snapshot()

    def get_status(self) -> dict[str, Any]:
        """Return the job's status dict; raises ``RuntimeError`` before submission."""
        return self._submitted_snapshot()
|
|
46
283
|
|
|
47
|
-
from .config import (
|
|
48
|
-
DatasetSource,
|
|
49
|
-
GEPAConfig,
|
|
50
|
-
MIPROConfig,
|
|
51
|
-
ModelProvider,
|
|
52
|
-
OptimizationTool,
|
|
53
|
-
PermittedModel,
|
|
54
|
-
PermittedModelsConfig,
|
|
55
|
-
ResearchConfig,
|
|
56
|
-
)
|
|
57
|
-
from .job import (
|
|
58
|
-
ResearchAgentJob,
|
|
59
|
-
ResearchAgentJobConfig,
|
|
60
|
-
ResearchAgentJobPoller,
|
|
61
|
-
)
|
|
62
284
|
|
|
63
285
|
__all__ = [
|
|
64
|
-
# CLI
|
|
65
|
-
"register",
|
|
66
|
-
# SDK - Main classes
|
|
67
286
|
"ResearchAgentJob",
|
|
68
287
|
"ResearchAgentJobConfig",
|
|
69
|
-
"ResearchAgentJobPoller",
|
|
70
|
-
# SDK - Config types
|
|
71
288
|
"ResearchConfig",
|
|
72
289
|
"DatasetSource",
|
|
73
290
|
"OptimizationTool",
|
|
@@ -77,10 +294,3 @@ __all__ = [
|
|
|
77
294
|
"PermittedModel",
|
|
78
295
|
"ModelProvider",
|
|
79
296
|
]
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
def register(cli: Any) -> None:
|
|
83
|
-
"""Register the agent command with the CLI."""
|
|
84
|
-
from .cli import register as _register
|
|
85
|
-
|
|
86
|
-
_register(cli)
|
|
@@ -885,11 +885,192 @@ def build_prompt_learning_payload(
|
|
|
885
885
|
return PromptLearningBuildResult(payload=payload, task_url=final_task_url)
|
|
886
886
|
|
|
887
887
|
|
|
888
|
+
def _first_truthy_attr(obj: Any, *names: str) -> Any:
    """Return the first truthy attribute of ``obj`` among ``names``.

    Mirrors a chained ``getattr(...) or getattr(...)``: when no attribute is
    truthy, the value of the last one looked up (or ``None``) is returned.
    """
    value = None
    for name in names:
        value = getattr(obj, name, None)
        if value:
            break
    return value


def build_prompt_learning_payload_from_mapping(
    *,
    raw_config: dict[str, Any],
    task_url: str | None,
    overrides: dict[str, Any],
    allow_experimental: bool | None = None,
    source_label: str = "programmatic",
) -> PromptLearningBuildResult:
    """Build payload for prompt learning job from a dictionary (programmatic use).

    This is the same as build_prompt_learning_payload but accepts a dict instead of a file path.
    Both functions route through the same PromptLearningConfig Pydantic validation.

    The task app URL is taken, in order, from ``overrides["task_url"]`` /
    ``task_url``, then the config's ``task_app_url``, then the
    ``TASK_APP_URL`` environment variable. The task app API key is resolved
    by ``ConfigResolver`` from ``overrides["task_app_api_key"]``, the
    ``ENVIRONMENT_API_KEY`` environment variable and the config.

    Args:
        raw_config: Configuration dictionary with the same structure as the TOML file.
            Should have a 'prompt_learning' section.
        task_url: Override for task_app_url
        overrides: Config overrides (merged into config). The keys
            ``backend``, ``task_url``, ``metadata`` and ``auto_start`` are
            consumed here and are not merged into the config. If an
            ``overrides`` key is present, its value is used as the override
            mapping instead of the top-level dict. This dict is not mutated.
        allow_experimental: Accepted for signature compatibility; this
            function does not currently read it.
        source_label: Label for logging/error messages (default: "programmatic")

    Returns:
        PromptLearningBuildResult with payload and task_url

    Raises:
        click.ClickException: if the config fails Pydantic validation, if a
            required GEPA section or seed list is missing, or if no task app
            URL could be resolved.
        ValueError: if the final config has no 'prompt_learning' key.

    Example:
        >>> result = build_prompt_learning_payload_from_mapping(  # doctest: +SKIP
        ...     raw_config={
        ...         "prompt_learning": {
        ...             "algorithm": "gepa",
        ...             "task_app_url": "https://tunnel.example.com",
        ...             "policy": {"model": "gpt-4o-mini", "provider": "openai"},
        ...             "gepa": {...},
        ...         }
        ...     },
        ...     task_url=None,
        ...     overrides={},
        ... )
    """
    ctx: dict[str, Any] = {"source": source_label}
    log_info("build_prompt_learning_payload_from_mapping invoked", ctx=ctx)
    from pydantic import ValidationError

    # SDK-SIDE VALIDATION: Catch errors BEFORE sending to backend
    from .validators import validate_prompt_learning_config

    # Use a pseudo-path for error messages (validator expects Path object)
    pseudo_path = Path(f"<{source_label}>")
    validate_prompt_learning_config(raw_config, pseudo_path)

    try:
        pl_cfg = PromptLearningConfig.from_mapping(raw_config)
    except ValidationError as exc:
        # Format validation errors for dict-based config
        lines: list[str] = []
        for error in exc.errors():
            loc = ".".join(str(part) for part in error.get("loc", ()))
            msg = error.get("msg", "invalid value")
            lines.append(f"{loc or '<root>'}: {msg}")
        details = "\n".join(f"  - {line}" for line in lines) or "  - Invalid configuration"
        raise click.ClickException(f"Config validation failed ({source_label}):\n{details}") from exc

    # Early validation: Check required fields for GEPA. The train seeds found
    # here are reused below instead of being looked up a second time.
    gepa_train_seeds = None
    if pl_cfg.algorithm == "gepa":
        if not pl_cfg.gepa:
            raise click.ClickException(
                "GEPA config missing: [prompt_learning.gepa] section is required"
            )
        evaluation = pl_cfg.gepa.evaluation
        if not evaluation:
            raise click.ClickException(
                "GEPA config missing: [prompt_learning.gepa.evaluation] section is required"
            )
        gepa_train_seeds = _first_truthy_attr(evaluation, "train_seeds", "seeds")
        if not gepa_train_seeds:
            raise click.ClickException(
                "GEPA config missing train_seeds: [prompt_learning.gepa.evaluation] must have 'train_seeds' or 'seeds' field"
            )
        val_seeds = _first_truthy_attr(evaluation, "val_seeds", "validation_seeds")
        if not val_seeds:
            raise click.ClickException(
                "GEPA config missing val_seeds: [prompt_learning.gepa.evaluation] must have 'val_seeds' or 'validation_seeds' field"
            )

    cli_task_url = overrides.get("task_url") or task_url
    env_task_url = os.environ.get("TASK_APP_URL")
    config_task_url = (pl_cfg.task_app_url or "").strip() or None

    # Resolve task_app_url with same precedence as file-based builder
    if cli_task_url:
        final_task_url = ConfigResolver.resolve(
            "task_app_url",
            cli_value=cli_task_url,
            env_value=None,
            config_value=config_task_url,
            required=True,
        )
    elif config_task_url:
        final_task_url = config_task_url
    else:
        final_task_url = ConfigResolver.resolve(
            "task_app_url",
            cli_value=None,
            env_value=env_task_url,
            config_value=None,
            required=True,
        )
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if final_task_url is None:
        raise click.ClickException("task_app_url could not be resolved")

    # Get task_app_api_key from config or environment
    config_api_key = (pl_cfg.task_app_api_key or "").strip() or None
    cli_api_key = overrides.get("task_app_api_key")
    env_api_key = os.environ.get("ENVIRONMENT_API_KEY")
    task_app_api_key = ConfigResolver.resolve(
        "task_app_api_key",
        cli_value=cli_api_key,
        env_value=env_api_key,
        config_value=config_api_key,
        required=True,
    )

    # Build config dict for backend
    config_dict = pl_cfg.to_dict()

    # Ensure task_app_url and task_app_api_key are set.
    # NOTE: if to_dict() produced no 'prompt_learning' key, the default dict
    # here is detached from config_dict; the final check below then rejects
    # the config unless the overrides supply that section.
    pl_section = config_dict.get("prompt_learning", {})
    if isinstance(pl_section, dict):
        pl_section["task_app_url"] = final_task_url
        pl_section["task_app_api_key"] = task_app_api_key

        # GEPA: Extract train_seeds from nested structure. Existing falsy
        # values (missing, None, empty) are replaced; non-empty ones are kept.
        if gepa_train_seeds:
            for seed_key in ("train_seeds", "evaluation_seeds"):
                if not pl_section.get(seed_key):
                    pl_section[seed_key] = gepa_train_seeds
    else:
        config_dict["prompt_learning"] = {
            "task_app_url": final_task_url,
            "task_app_api_key": task_app_api_key,
        }

    # Build payload matching backend API format
    if "overrides" in overrides:
        # Tolerate an explicit None under the nested key.
        raw_overrides = overrides["overrides"] or {}
    else:
        raw_overrides = overrides
    config_overrides = {
        k: v for k, v in raw_overrides.items()
        if k not in ("backend", "task_url", "metadata", "auto_start")
    }

    # Merge overrides into config_dict
    if config_overrides:
        from synth_ai.cli.local.experiment_queue.config_utils import _deep_update
        _deep_update(config_dict, config_overrides)

    # Final validation
    if "prompt_learning" not in config_dict:
        raise ValueError(
            "config_dict must have 'prompt_learning' key. "
            f"Found keys: {list(config_dict.keys())}"
        )

    # Copy the metadata so adding backend_base_url below never mutates the
    # caller's overrides["metadata"] dict.
    metadata: dict[str, Any] = dict(overrides.get("metadata") or {})

    payload: dict[str, Any] = {
        "algorithm": pl_cfg.algorithm,
        "config_body": config_dict,
        "overrides": config_overrides,
        "metadata": metadata,
        "auto_start": overrides.get("auto_start", True),
    }

    backend = overrides.get("backend")
    if backend:
        metadata["backend_base_url"] = ensure_api_base(str(backend))

    return PromptLearningBuildResult(payload=payload, task_url=final_task_url)
|
|
1066
|
+
|
|
1067
|
+
|
|
888
1068
|
__all__ = [
|
|
889
1069
|
"PromptLearningBuildResult",
|
|
890
1070
|
"RLBuildResult",
|
|
891
1071
|
"SFTBuildResult",
|
|
892
1072
|
"build_prompt_learning_payload",
|
|
1073
|
+
"build_prompt_learning_payload_from_mapping",
|
|
893
1074
|
"build_rl_payload",
|
|
894
1075
|
"build_sft_payload",
|
|
895
1076
|
]
|