synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Potentially problematic release.
This version of synth-ai might be problematic.
- synth_ai/__init__.py +13 -13
- synth_ai/cli/__init__.py +6 -15
- synth_ai/cli/commands/eval/__init__.py +6 -15
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +236 -1091
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +44 -117
- synth_ai/cli/commands/filter/core.py +7 -7
- synth_ai/cli/commands/filter/validation.py +2 -2
- synth_ai/cli/commands/smoke/core.py +7 -17
- synth_ai/cli/commands/status/__init__.py +1 -64
- synth_ai/cli/commands/status/client.py +50 -151
- synth_ai/cli/commands/status/config.py +3 -83
- synth_ai/cli/commands/status/errors.py +4 -13
- synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +18 -63
- synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
- synth_ai/cli/commands/status/subcommands/models.py +18 -62
- synth_ai/cli/commands/status/subcommands/runs.py +16 -63
- synth_ai/cli/commands/status/subcommands/session.py +67 -172
- synth_ai/cli/commands/status/subcommands/summary.py +24 -32
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +16 -107
- synth_ai/cli/commands/train/__init__.py +18 -20
- synth_ai/cli/commands/train/errors.py +3 -3
- synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
- synth_ai/cli/commands/train/validation.py +7 -7
- synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
- synth_ai/cli/commands/train/verifier_validation.py +235 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/math/config.toml +0 -1
- synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
- synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
- synth_ai/cli/lib/apps/task_app.py +12 -13
- synth_ai/cli/lib/task_app_discovery.py +6 -6
- synth_ai/cli/lib/train_cfgs.py +10 -10
- synth_ai/cli/task_apps/__init__.py +11 -0
- synth_ai/cli/task_apps/commands.py +7 -15
- synth_ai/core/env.py +12 -1
- synth_ai/core/errors.py +1 -2
- synth_ai/core/integrations/cloudflare.py +209 -33
- synth_ai/core/tracing_v3/abstractions.py +46 -0
- synth_ai/data/__init__.py +3 -30
- synth_ai/data/enums.py +1 -20
- synth_ai/data/rewards.py +100 -3
- synth_ai/products/graph_evolve/__init__.py +1 -2
- synth_ai/products/graph_evolve/config.py +16 -16
- synth_ai/products/graph_evolve/converters/__init__.py +3 -3
- synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +45 -35
- synth_ai/sdk/api/eval/__init__.py +33 -0
- synth_ai/sdk/api/eval/job.py +732 -0
- synth_ai/sdk/api/research_agent/__init__.py +276 -66
- synth_ai/sdk/api/train/builders.py +181 -0
- synth_ai/sdk/api/train/cli.py +41 -33
- synth_ai/sdk/api/train/configs/__init__.py +6 -4
- synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
- synth_ai/sdk/api/train/configs/rl.py +264 -16
- synth_ai/sdk/api/train/configs/sft.py +165 -1
- synth_ai/sdk/api/train/graph_validators.py +12 -12
- synth_ai/sdk/api/train/graphgen.py +169 -51
- synth_ai/sdk/api/train/graphgen_models.py +95 -45
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +36 -0
- synth_ai/sdk/api/train/prompt_learning.py +390 -60
- synth_ai/sdk/api/train/rl.py +41 -5
- synth_ai/sdk/api/train/sft.py +2 -0
- synth_ai/sdk/api/train/task_app.py +20 -0
- synth_ai/sdk/api/train/validators.py +17 -17
- synth_ai/sdk/graphs/completions.py +239 -33
- synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
- synth_ai/sdk/learning/__init__.py +35 -5
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +294 -0
- synth_ai/sdk/learning/prompt_learning_client.py +1 -1
- synth_ai/sdk/learning/prompt_learning_types.py +2 -1
- synth_ai/sdk/learning/rl/__init__.py +0 -4
- synth_ai/sdk/learning/rl/contracts.py +0 -4
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +93 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +49 -0
- synth_ai/sdk/streaming/handlers.py +6 -6
- synth_ai/sdk/streaming/streamer.py +10 -6
- synth_ai/sdk/task/__init__.py +18 -5
- synth_ai/sdk/task/apps/__init__.py +37 -1
- synth_ai/sdk/task/client.py +9 -1
- synth_ai/sdk/task/config.py +6 -11
- synth_ai/sdk/task/contracts.py +137 -95
- synth_ai/sdk/task/in_process.py +32 -22
- synth_ai/sdk/task/in_process_runner.py +9 -4
- synth_ai/sdk/task/rubrics/__init__.py +2 -3
- synth_ai/sdk/task/rubrics/loaders.py +4 -4
- synth_ai/sdk/task/rubrics/strict.py +3 -4
- synth_ai/sdk/task/server.py +76 -16
- synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
- synth_ai/sdk/task/validators.py +34 -49
- synth_ai/sdk/training/__init__.py +7 -16
- synth_ai/sdk/tunnels/__init__.py +118 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/sdk/tunnels/tunneled_api.py +363 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
- synth_ai/cli/commands/baseline/__init__.py +0 -12
- synth_ai/cli/commands/baseline/core.py +0 -636
- synth_ai/cli/commands/baseline/list.py +0 -94
- synth_ai/cli/commands/eval/errors.py +0 -81
- synth_ai/cli/commands/status/formatters.py +0 -164
- synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
- synth_ai/cli/commands/status/subcommands/usage.py +0 -203
- synth_ai/cli/commands/train/judge_validation.py +0 -305
- synth_ai/cli/usage.py +0 -159
- synth_ai/data/specs.py +0 -36
- synth_ai/sdk/api/research_agent/cli.py +0 -428
- synth_ai/sdk/api/research_agent/config.py +0 -357
- synth_ai/sdk/api/research_agent/job.py +0 -717
- synth_ai/sdk/baseline/__init__.py +0 -25
- synth_ai/sdk/baseline/config.py +0 -209
- synth_ai/sdk/baseline/discovery.py +0 -216
- synth_ai/sdk/baseline/execution.py +0 -154
- synth_ai/sdk/judging/__init__.py +0 -15
- synth_ai/sdk/judging/base.py +0 -24
- synth_ai/sdk/judging/client.py +0 -191
- synth_ai/sdk/judging/types.py +0 -42
- synth_ai/sdk/research_agent/__init__.py +0 -34
- synth_ai/sdk/research_agent/container_builder.py +0 -328
- synth_ai/sdk/research_agent/container_spec.py +0 -198
- synth_ai/sdk/research_agent/defaults.py +0 -34
- synth_ai/sdk/research_agent/results_collector.py +0 -69
- synth_ai/sdk/specs/__init__.py +0 -46
- synth_ai/sdk/specs/dataclasses.py +0 -149
- synth_ai/sdk/specs/loader.py +0 -144
- synth_ai/sdk/specs/serializer.py +0 -199
- synth_ai/sdk/specs/validation.py +0 -250
- synth_ai/sdk/tracing/__init__.py +0 -39
- synth_ai/sdk/usage/__init__.py +0 -37
- synth_ai/sdk/usage/client.py +0 -171
- synth_ai/sdk/usage/models.py +0 -261
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
synth_ai/sdk/judging/client.py
DELETED
@@ -1,191 +0,0 @@
-"""Experimental Judge API client.
-
-This surface is experimental and subject to change without notice.
-Set environment variable `SYNTH_SILENCE_EXPERIMENTAL=1` to silence warnings.
-"""
-
-from __future__ import annotations
-
-import os
-import warnings
-from typing import Any, Literal, TypedDict
-
-from synth_ai.core.http import AsyncHttpClient, HTTPError
-from synth_ai.core.tracing_v3.serialization import normalize_for_json
-from synth_ai.sdk.graphs import VerifierClient as GraphVerifierClient
-
-Provider = Literal["groq", "gemini"]
-
-
-class JudgeOptions(TypedDict, total=False):
-    event: bool
-    outcome: bool
-    rubric_id: str
-    rubric_overrides: dict[str, Any]
-    provider: Provider
-    model: str
-    max_concurrency: int
-    verifier_type: str
-
-
-class JudgeScoreResponse(TypedDict, total=False):
-    status: str
-    event_rewards: list[dict[str, Any]]
-    outcome_reward: dict[str, Any]
-    details: dict[str, Any]
-
-
-class JudgeClient:
-    """Legacy client for LLM-based evaluation of task app traces.
-
-    This client provides programmatic access to Synth AI's judge API, which uses
-    LLMs to evaluate task execution traces and generate rewards. The judge can
-    evaluate both event-level (step-by-step) and outcome-level (episode-level) rewards.
-
-    .. warning::
-        This API is experimental and subject to change without notice.
-        Set `SYNTH_SILENCE_EXPERIMENTAL=1` to silence warnings.
-
-    Example:
-        >>> from synth_ai.sdk.judging import JudgeClient, JudgeOptions
-        >>>
-        >>> client = JudgeClient(
-        ...     base_url="https://api.usesynth.ai",
-        ...     api_key=os.environ["SYNTH_API_KEY"],
-        ... )
-        >>>
-        >>> # Score a trace with outcome reward
-        >>> result = await client.score(
-        ...     trace=my_trace_dict,
-        ...     policy_name="my_policy",
-        ...     task_app_id="heartdisease",
-        ...     options=JudgeOptions(
-        ...         outcome=True,
-        ...         rubric_id="accuracy",
-        ...         provider="groq",
-        ...         model="llama-3.1-8b-instant",
-        ...     ),
-        ... )
-        >>>
-        >>> print(f"Outcome reward: {result['outcome_reward']}")
-    """
-
-    def __init__(self, base_url: str, api_key: str, *, timeout: float = 60.0) -> None:
-        """Initialize the judge client.
-
-        Args:
-            base_url: Base URL for the Synth AI API
-            api_key: API key for authentication
-            timeout: Request timeout in seconds (default: 60.0)
-        """
-        _silence = (os.getenv("SYNTH_SILENCE_EXPERIMENTAL") or "").strip().lower()
-        if _silence not in {"1", "true", "t", "yes", "y", "on"}:
-            warnings.warn(
-                "Legacy API: synth_ai.sdk.judging.JudgeClient is legacy. "
-                "Use synth_ai.sdk.graphs.VerifierClient or GraphCompletionsClient instead.",
-                UserWarning,
-                stacklevel=2,
-            )
-        self._base = base_url.rstrip("/")
-        self._key = api_key
-        self._timeout = timeout
-
-    async def score(
-        self,
-        *,
-        trace: dict[str, Any] | Any,
-        policy_name: str,
-        task_app_id: str,
-        options: JudgeOptions,
-        rubric: dict[str, Any] | None = None,
-        verifier_type: str | None = None,
-        task_app_base_url: str | None = None,
-    ) -> JudgeScoreResponse:
-        """Score a task execution trace using LLM-based evaluation.
-
-        This method sends a trace to the judge API, which evaluates it according
-        to the provided rubric and returns event-level and/or outcome-level rewards.
-
-        Args:
-            trace: Task execution trace (SessionTrace dict or compatible object)
-            policy_name: Name of the policy that generated this trace
-            task_app_id: Identifier for the task app (e.g., "heartdisease")
-            options: Judge configuration options:
-                - event: Whether to generate event-level rewards (default: False)
-                - outcome: Whether to generate outcome-level reward (default: False)
-                - rubric_id: Rubric identifier to use for evaluation
-                - rubric_overrides: Optional rubric modifications
-                - provider: LLM provider ("groq" or "gemini")
-                - model: Model identifier (e.g., "llama-3.1-8b-instant")
-                - max_concurrency: Max concurrent judge calls (default: 1)
-            rubric: Optional explicit rubric criteria (event/outcome lists)
-            verifier_type: Optional zero-shot verifier graph ID (e.g., "zero_shot_verifier_single")
-            task_app_base_url: Optional base URL for task app (for rubric fetching)
-
-        Returns:
-            JudgeScoreResponse with:
-                - status: "ok" or error status
-                - event_rewards: List of event-level reward dicts (if event=True)
-                - outcome_reward: Outcome-level reward dict (if outcome=True)
-                - details: Additional evaluation details
-
-        Raises:
-            ValueError: If validation fails or rubric is invalid
-            PermissionError: If authentication fails
-            FileNotFoundError: If task app or rubric not found
-            Exception: For rate limiting or transient errors
-        """
-        trace_payload = normalize_for_json(trace)
-        task_app_payload = {"id": task_app_id}
-        if task_app_base_url:
-            task_app_payload["base_url"] = task_app_base_url
-
-        selected_verifier = verifier_type or (options or {}).get("verifier_type")
-        if selected_verifier:
-            graph_input = {
-                "policy_name": policy_name,
-                "task_app": task_app_payload,
-                "session_trace": trace_payload,
-                "trace": trace_payload,
-                "options": options or {},
-            }
-            if rubric is not None:
-                graph_input["rubric"] = normalize_for_json(rubric)
-            body = {"job_id": selected_verifier, "input": graph_input}
-        else:
-            body = {
-                "policy_name": policy_name,
-                "task_app": task_app_payload,
-                "trace": trace_payload,
-                "options": options or {},
-            }
-            if rubric is not None:
-                body["rubric"] = normalize_for_json(rubric)
-        try:
-            async with AsyncHttpClient(self._base, self._key, timeout=self._timeout) as http:
-                if selected_verifier:
-                    js = await http.post_json("/api/graphs/completions", json=body)
-                    if isinstance(js, dict) and "output" in js:
-                        js = js["output"]
-                else:
-                    js = await http.post_json("/api/judge/v1/score", json=body)
-                if not isinstance(js, dict):
-                    raise ValueError("invalid_judge_response_shape")
-                return js  # type: ignore[return-value]
-        except HTTPError as err:  # map to friendlier exceptions
-            status = int(getattr(err, "status", 0) or 0)
-            if status in (400, 422):
-                raise ValueError(f"judge_validation_error: {err.detail}") from err
-            if status in (401, 403):
-                raise PermissionError(f"judge_auth_error: {err.detail}") from err
-            if status == 404:
-                raise FileNotFoundError(f"judge_route_not_found: {err.detail}") from err
-            if status == 429:
-                raise Exception("judge_rate_limited") from err  # replace with RetryLater in future
-            if status >= 500:
-                raise Exception("judge_transient_error") from err  # replace with TransientError in future
-            raise
-
-
-class VerifierClient(GraphVerifierClient):
-    """Deprecated alias for graph-based VerifierClient."""
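For reference, a minimal sketch of how the removed JudgeClient surface was typically consumed, including the HTTP-status-to-exception mapping implemented in score() above. The trace payload, policy name, and model are illustrative placeholders, and the import only resolves on synth-ai 0.4.1 and earlier, since the module is deleted in 0.4.4.

import asyncio
import os

from synth_ai.sdk.judging import JudgeClient, JudgeOptions  # module removed in 0.4.4


async def score_once() -> None:
    # Placeholder trace; a real call passes a SessionTrace dict from tracing_v3.
    trace = {"session_id": "demo", "events": []}
    client = JudgeClient(
        base_url="https://api.usesynth.ai",
        api_key=os.environ["SYNTH_API_KEY"],
    )
    try:
        result = await client.score(
            trace=trace,
            policy_name="my_policy",
            task_app_id="heartdisease",
            options=JudgeOptions(outcome=True, provider="groq", model="llama-3.1-8b-instant"),
        )
    except ValueError as exc:          # 400/422 mapped to judge_validation_error
        print(f"validation failed: {exc}")
    except PermissionError as exc:     # 401/403 mapped to judge_auth_error
        print(f"auth failed: {exc}")
    except FileNotFoundError as exc:   # 404 mapped to judge_route_not_found
        print(f"route not found: {exc}")
    else:
        print(result.get("outcome_reward"))


if __name__ == "__main__":
    asyncio.run(score_once())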
synth_ai/sdk/judging/types.py
DELETED
@@ -1,42 +0,0 @@
-from __future__ import annotations
-
-from typing import Literal, TypedDict
-
-Track = Literal["process", "reasoning", "progress", "outcome"]
-
-
-class Judgement(TypedDict, total=False):
-    key: str
-    title: str
-    description: str
-    score: float
-    reason: str
-    confidence: float
-    scale: Literal["binary", "bounded", "count", "custom"]
-    source: dict
-
-
-class RewardJudgement(TypedDict, total=False):
-    judgement: Judgement
-    scope: Literal["step", "event", "outcome"]
-    turn: int | None
-    episode_id: str | None
-    reward_value: float | None
-    links: dict
-
-
-class TrackAggregate(TypedDict, total=False):
-    mean: float
-    median: float
-    std: float
-    n: int
-
-
-class RewardMetadata(TypedDict, total=False):
-    per_window: list[RewardJudgement]
-    aggregates: dict[Track, TrackAggregate]
-    overall: dict[str, float]  # {"final_outcome_score": float}
-    rubric: dict  # {"ids": {...}, "hash": "..."}
-    model_info: dict  # {"model": "...", ...}
-
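A short illustrative sketch of how the removed reward-metadata shapes fit together; the values are placeholders, not output from any SDK call, and the import only works before 0.4.4.

from synth_ai.sdk.judging.types import (  # module removed in 0.4.4
    Judgement,
    RewardJudgement,
    RewardMetadata,
    TrackAggregate,
)

# One outcome-scoped judgement rolled up into a RewardMetadata record.
judgement = Judgement(key="accuracy", score=1.0, reason="prediction matched label", scale="binary")
reward = RewardJudgement(judgement=judgement, scope="outcome", turn=None, episode_id="ep-1", reward_value=1.0)
metadata = RewardMetadata(
    per_window=[reward],
    aggregates={"outcome": TrackAggregate(mean=1.0, median=1.0, std=0.0, n=1)},
    overall={"final_outcome_score": 1.0},
)
print(metadata["overall"]["final_outcome_score"])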
synth_ai/sdk/research_agent/__init__.py
DELETED
@@ -1,34 +0,0 @@
-from synth_ai.sdk.research_agent.container_builder import (
-    ContainerBackend,
-    DockerBackend,
-    ModalBackend,
-    get_backend,
-)
-from synth_ai.sdk.research_agent.container_spec import ContainerSpec
-from synth_ai.sdk.research_agent.defaults import (
-    DEFAULT_BACKEND,
-    DEFAULT_BASE_IMAGE,
-    DEFAULT_INSTRUCTIONS,
-    DEFAULT_PACKAGES,
-    DEFAULT_PYTHON_VERSION,
-    DEFAULT_REASONING_EFFORT,
-    DEFAULT_RESULT_PATTERNS,
-)
-from synth_ai.sdk.research_agent.results_collector import ResultsCollector
-
-__all__ = [
-    "ContainerBackend",
-    "ContainerSpec",
-    "DockerBackend",
-    "ModalBackend",
-    "ResultsCollector",
-    "get_backend",
-    "DEFAULT_BACKEND",
-    "DEFAULT_BASE_IMAGE",
-    "DEFAULT_INSTRUCTIONS",
-    "DEFAULT_PACKAGES",
-    "DEFAULT_PYTHON_VERSION",
-    "DEFAULT_REASONING_EFFORT",
-    "DEFAULT_RESULT_PATTERNS",
-]
-
synth_ai/sdk/research_agent/container_builder.py
DELETED
@@ -1,328 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-import base64
-import contextlib
-import fnmatch
-import io
-import tarfile
-import tempfile
-import time
-import uuid
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Dict, Iterable, Tuple
-
-from synth_ai.sdk.research_agent.container_spec import ContainerSpec
-from synth_ai.sdk.research_agent.defaults import DEFAULT_BACKEND
-
-
-class ContainerBackend(ABC):
-    """Abstract base for container execution backends."""
-
-    @abstractmethod
-    async def provision(self, spec: ContainerSpec) -> str:
-        """Provision a new container and return its id/handle."""
-
-    @abstractmethod
-    async def execute(
-        self,
-        container_id: str,
-        command: str,
-        *,
-        env: Dict[str, str] | None = None,
-        workdir: Path | None = None,
-    ) -> Dict[str, str | int]:
-        """Execute a command in the container."""
-
-    @abstractmethod
-    async def collect_artifacts(self, container_id: str, patterns: Iterable[str]) -> Dict[str, bytes]:
-        """Pull artifacts that match any of the glob patterns."""
-
-    @abstractmethod
-    async def destroy(self, container_id: str) -> None:
-        """Tear down container resources."""
-
-
-class DockerBackend(ContainerBackend):
-    """Docker implementation using docker-py."""
-
-    def __init__(self, *, client=None):
-        self._client = client
-        self._containers: Dict[str, Tuple[object, ContainerSpec]] = {}
-
-    def _ensure_client(self):
-        if self._client is None:
-            try:
-                import docker  # type: ignore
-            except ImportError as exc:
-                raise RuntimeError("docker SDK is not installed. Add docker>=7.0.0 to dependencies.") from exc
-            self._client = docker.from_env()
-        return self._client
-
-    async def provision(self, spec: ContainerSpec) -> str:
-        spec.validate()
-        client = self._ensure_client()
-        context_bytes = spec.build_context()
-        image_tag = f"research-agent:{int(time.time())}"
-
-        def _build():
-            return client.images.build(
-                fileobj=io.BytesIO(context_bytes),
-                custom_context=True,
-                rm=True,
-                nocache=True,
-                tag=image_tag,
-                buildargs=spec.build_args,
-            )
-
-        loop = asyncio.get_running_loop()
-        image, _ = await loop.run_in_executor(None, _build)
-
-        container = client.containers.create(
-            image=image.id,
-            command="sleep infinity",
-            environment={**spec.env_vars, **spec.secrets},
-            tty=True,
-            detach=True,
-            working_dir=str(spec.workdir),
-        )
-        container.start()
-        self._containers[container.id] = (container, spec)
-        return container.id
-
-    async def execute(
-        self,
-        container_id: str,
-        command: str,
-        *,
-        env: Dict[str, str] | None = None,
-        workdir: Path | None = None,
-    ) -> Dict[str, str | int]:
-        container, spec = self._containers[container_id]
-        if env is None:
-            env = {}
-        exec_env = {**spec.env_vars, **spec.secrets, **env}
-        workdir_str = str(workdir or spec.workdir)
-
-        def _run():
-            result = container.exec_run(  # type: ignore[attr-defined]
-                cmd=["bash", "-lc", command],
-                environment=exec_env,
-                workdir=workdir_str,
-                demux=True,
-            )
-            stdout, stderr = result.output
-            return {
-                "exit_code": result.exit_code,
-                "stdout": (stdout or b"").decode(),
-                "stderr": (stderr or b"").decode(),
-            }
-
-        loop = asyncio.get_running_loop()
-        return await loop.run_in_executor(None, _run)
-
-    async def collect_artifacts(self, container_id: str, patterns: Iterable[str]) -> Dict[str, bytes]:
-        container, spec = self._containers[container_id]
-        loop = asyncio.get_running_loop()
-
-        def _pull():
-            try:
-                stream, _ = container.get_archive(str(spec.artifacts_dir))  # type: ignore[attr-defined]
-            except Exception:
-                return {}
-            tar_bytes = b"".join(stream)
-            collected: Dict[str, bytes] = {}
-            with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r:*") as tar:
-                for member in tar.getmembers():
-                    if not member.isfile():
-                        continue
-                    relative_name = str(Path(member.name).name)
-                    if not any(fnmatch.fnmatch(relative_name, pat) for pat in patterns):
-                        continue
-                    file_obj = tar.extractfile(member)
-                    if file_obj:
-                        collected[relative_name] = file_obj.read()
-            return collected
-
-        return await loop.run_in_executor(None, _pull)
-
-    async def destroy(self, container_id: str) -> None:
-        container, _ = self._containers.pop(container_id, (None, None))
-        if container is None:
-            return
-
-        def _stop():
-            with contextlib.suppress(Exception):
-                container.kill()  # type: ignore[attr-defined]
-            with contextlib.suppress(Exception):
-                container.remove(force=True)  # type: ignore[attr-defined]
-
-        loop = asyncio.get_running_loop()
-        await loop.run_in_executor(None, _stop)
-
-
-class ModalBackend(ContainerBackend):
-    """Modal implementation using modal SDK. Returns artifacts inline from execute()."""
-
-    def __init__(self):
-        self._runs: Dict[str, Dict[str, object]] = {}
-
-    def _write_build_context(self, spec: ContainerSpec) -> tempfile.TemporaryDirectory:
-        """Materialize Dockerfile + overlay files to a temp dir for Modal build."""
-        temp_dir = tempfile.TemporaryDirectory()
-        ctx = Path(temp_dir.name)
-        (ctx / "Dockerfile").write_text(spec.to_dockerfile())
-
-        overlay_root = ctx / "overlay_files"
-        for rel_path, content in spec.rendered_overlay_files().items():
-            target = overlay_root / rel_path
-            target.parent.mkdir(parents=True, exist_ok=True)
-            target.write_bytes(content)
-
-        for rel_path, content in spec.files.items():
-            if not str(rel_path).startswith("/"):
-                continue
-            data = content.encode() if isinstance(content, str) else content
-            target = ctx / str(rel_path).lstrip("/")
-            target.parent.mkdir(parents=True, exist_ok=True)
-            target.write_bytes(data)
-
-        return temp_dir
-
-    async def provision(self, spec: ContainerSpec) -> str:
-        spec.validate()
-        try:
-            import modal  # type: ignore
-        except ImportError as exc:  # pragma: no cover - runtime import guard
-            raise RuntimeError("modal SDK is not installed. Add modal>=1.1.1 to dependencies.") from exc
-
-        ctx_dir = self._write_build_context(spec)
-        loop = asyncio.get_running_loop()
-
-        def _build_image():
-            return modal.Image.from_dockerfile(
-                path=ctx_dir.name,
-                build_args=spec.build_args,
-                force_build=True,
-            )
-
-        image = await loop.run_in_executor(None, _build_image)
-
-        # Combine env_vars and secrets into a Modal Secret
-        # Modal function decorator doesn't accept 'env' parameter directly
-        # Environment variables must be passed via secrets
-        combined_env: dict[str, str | None] = {**spec.env_vars, **spec.secrets}
-        secret_obj = None
-        if combined_env:
-            secret_obj = modal.Secret.from_dict(combined_env)
-
-        app = modal.App(f"oneshot-research-{int(time.time())}")
-
-        workdir_str = str(spec.workdir)
-
-        @app.function(
-            image=image,
-            timeout=60 * 60,
-            secrets=[secret_obj] if secret_obj else [],
-        )
-        def run_task(command: str, patterns: list[str], artifacts_dir: str = "/app/artifacts") -> Dict:
-            """Execute the agent and pull artifacts matching patterns."""
-            import glob
-            import os
-            import subprocess
-
-            result = subprocess.run(
-                ["bash", "-lc", command],
-                capture_output=True,
-                text=True,
-                cwd=workdir_str,
-            )
-
-            artifacts: Dict[str, str] = {}
-            for pat in patterns:
-                for path in glob.glob(os.path.join(artifacts_dir, pat)):
-                    if not os.path.isfile(path):
-                        continue
-                    name = os.path.basename(path)
-                    with open(path, "rb") as f:
-                        artifacts[name] = base64.b64encode(f.read()).decode()
-
-            return {
-                "exit_code": result.returncode,
-                "stdout": result.stdout,
-                "stderr": result.stderr,
-                "artifacts": artifacts,
-            }
-
-        container_id = str(uuid.uuid4())
-        self._runs[container_id] = {
-            "app": app,
-            "function": run_task,
-            "result": None,
-            "ctx_dir": ctx_dir,
-            "patterns": tuple(spec.result_matchers()),
-        }
-        return container_id
-
-    async def execute(
-        self,
-        container_id: str,
-        command: str,
-        *,
-        env: Dict[str, str] | None = None,
-        workdir: Path | None = None,
-    ) -> Dict[str, str | int]:
-        run_info = self._runs.get(container_id)
-        if not run_info:
-            raise ValueError(f"Unknown container_id: {container_id}")
-        app = run_info["app"]
-        run_fn = run_info["function"]
-        patterns = list(run_info["patterns"])
-
-        loop = asyncio.get_running_loop()
-
-        def _call():
-            with app.run():
-                return run_fn.call(command, patterns)
-
-        result = await loop.run_in_executor(None, _call)
-        run_info["result"] = result
-        return {
-            "exit_code": result.get("exit_code", -1),
-            "stdout": result.get("stdout", ""),
-            "stderr": result.get("stderr", ""),
-        }
-
-    async def collect_artifacts(self, container_id: str, patterns: Iterable[str]) -> Dict[str, bytes]:
-        run_info = self._runs.get(container_id)
-        if not run_info:
-            return {}
-        result = run_info.get("result") or {}
-        artifacts: Dict[str, bytes] = {}
-        encoded = result.get("artifacts") or {}  # type: ignore[misc]
-        for name, b64 in encoded.items():
-            try:
-                artifacts[name] = base64.b64decode(b64)
-            except Exception:
-                continue
-        return artifacts
-
-    async def destroy(self, container_id: str) -> None:
-        info = self._runs.pop(container_id, None)
-        if not info:
-            return
-        ctx_dir = info.get("ctx_dir")
-        if ctx_dir and hasattr(ctx_dir, "cleanup"):
-            with contextlib.suppress(Exception):
-                ctx_dir.cleanup()  # type: ignore[call-arg]
-
-
-def get_backend(name: str = DEFAULT_BACKEND) -> ContainerBackend:
-    """Resolve backend by name."""
-    normalized = (name or DEFAULT_BACKEND).lower()
-    if normalized == "docker":
-        return DockerBackend()
-    if normalized == "modal":
-        return ModalBackend()
-    raise ValueError(f"Unsupported container backend: {name}")