synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (153):
  1. synth_ai/__init__.py +13 -13
  2. synth_ai/cli/__init__.py +6 -15
  3. synth_ai/cli/commands/eval/__init__.py +6 -15
  4. synth_ai/cli/commands/eval/config.py +338 -0
  5. synth_ai/cli/commands/eval/core.py +236 -1091
  6. synth_ai/cli/commands/eval/runner.py +704 -0
  7. synth_ai/cli/commands/eval/validation.py +44 -117
  8. synth_ai/cli/commands/filter/core.py +7 -7
  9. synth_ai/cli/commands/filter/validation.py +2 -2
  10. synth_ai/cli/commands/smoke/core.py +7 -17
  11. synth_ai/cli/commands/status/__init__.py +1 -64
  12. synth_ai/cli/commands/status/client.py +50 -151
  13. synth_ai/cli/commands/status/config.py +3 -83
  14. synth_ai/cli/commands/status/errors.py +4 -13
  15. synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
  16. synth_ai/cli/commands/status/subcommands/config.py +13 -0
  17. synth_ai/cli/commands/status/subcommands/files.py +18 -63
  18. synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
  19. synth_ai/cli/commands/status/subcommands/models.py +18 -62
  20. synth_ai/cli/commands/status/subcommands/runs.py +16 -63
  21. synth_ai/cli/commands/status/subcommands/session.py +67 -172
  22. synth_ai/cli/commands/status/subcommands/summary.py +24 -32
  23. synth_ai/cli/commands/status/subcommands/utils.py +41 -0
  24. synth_ai/cli/commands/status/utils.py +16 -107
  25. synth_ai/cli/commands/train/__init__.py +18 -20
  26. synth_ai/cli/commands/train/errors.py +3 -3
  27. synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
  28. synth_ai/cli/commands/train/validation.py +7 -7
  29. synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
  30. synth_ai/cli/commands/train/verifier_validation.py +235 -0
  31. synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
  32. synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
  33. synth_ai/cli/demo_apps/math/config.toml +0 -1
  34. synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
  35. synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
  36. synth_ai/cli/lib/apps/task_app.py +12 -13
  37. synth_ai/cli/lib/task_app_discovery.py +6 -6
  38. synth_ai/cli/lib/train_cfgs.py +10 -10
  39. synth_ai/cli/task_apps/__init__.py +11 -0
  40. synth_ai/cli/task_apps/commands.py +7 -15
  41. synth_ai/core/env.py +12 -1
  42. synth_ai/core/errors.py +1 -2
  43. synth_ai/core/integrations/cloudflare.py +209 -33
  44. synth_ai/core/tracing_v3/abstractions.py +46 -0
  45. synth_ai/data/__init__.py +3 -30
  46. synth_ai/data/enums.py +1 -20
  47. synth_ai/data/rewards.py +100 -3
  48. synth_ai/products/graph_evolve/__init__.py +1 -2
  49. synth_ai/products/graph_evolve/config.py +16 -16
  50. synth_ai/products/graph_evolve/converters/__init__.py +3 -3
  51. synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
  52. synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
  53. synth_ai/products/graph_gepa/__init__.py +23 -0
  54. synth_ai/products/graph_gepa/converters/__init__.py +19 -0
  55. synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
  56. synth_ai/sdk/__init__.py +45 -35
  57. synth_ai/sdk/api/eval/__init__.py +33 -0
  58. synth_ai/sdk/api/eval/job.py +732 -0
  59. synth_ai/sdk/api/research_agent/__init__.py +276 -66
  60. synth_ai/sdk/api/train/builders.py +181 -0
  61. synth_ai/sdk/api/train/cli.py +41 -33
  62. synth_ai/sdk/api/train/configs/__init__.py +6 -4
  63. synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
  64. synth_ai/sdk/api/train/configs/rl.py +264 -16
  65. synth_ai/sdk/api/train/configs/sft.py +165 -1
  66. synth_ai/sdk/api/train/graph_validators.py +12 -12
  67. synth_ai/sdk/api/train/graphgen.py +169 -51
  68. synth_ai/sdk/api/train/graphgen_models.py +95 -45
  69. synth_ai/sdk/api/train/local_api.py +10 -0
  70. synth_ai/sdk/api/train/pollers.py +36 -0
  71. synth_ai/sdk/api/train/prompt_learning.py +390 -60
  72. synth_ai/sdk/api/train/rl.py +41 -5
  73. synth_ai/sdk/api/train/sft.py +2 -0
  74. synth_ai/sdk/api/train/task_app.py +20 -0
  75. synth_ai/sdk/api/train/validators.py +17 -17
  76. synth_ai/sdk/graphs/completions.py +239 -33
  77. synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
  78. synth_ai/sdk/learning/__init__.py +35 -5
  79. synth_ai/sdk/learning/context_learning_client.py +531 -0
  80. synth_ai/sdk/learning/context_learning_types.py +294 -0
  81. synth_ai/sdk/learning/prompt_learning_client.py +1 -1
  82. synth_ai/sdk/learning/prompt_learning_types.py +2 -1
  83. synth_ai/sdk/learning/rl/__init__.py +0 -4
  84. synth_ai/sdk/learning/rl/contracts.py +0 -4
  85. synth_ai/sdk/localapi/__init__.py +40 -0
  86. synth_ai/sdk/localapi/apps/__init__.py +28 -0
  87. synth_ai/sdk/localapi/client.py +10 -0
  88. synth_ai/sdk/localapi/contracts.py +10 -0
  89. synth_ai/sdk/localapi/helpers.py +519 -0
  90. synth_ai/sdk/localapi/rollouts.py +93 -0
  91. synth_ai/sdk/localapi/server.py +29 -0
  92. synth_ai/sdk/localapi/template.py +49 -0
  93. synth_ai/sdk/streaming/handlers.py +6 -6
  94. synth_ai/sdk/streaming/streamer.py +10 -6
  95. synth_ai/sdk/task/__init__.py +18 -5
  96. synth_ai/sdk/task/apps/__init__.py +37 -1
  97. synth_ai/sdk/task/client.py +9 -1
  98. synth_ai/sdk/task/config.py +6 -11
  99. synth_ai/sdk/task/contracts.py +137 -95
  100. synth_ai/sdk/task/in_process.py +32 -22
  101. synth_ai/sdk/task/in_process_runner.py +9 -4
  102. synth_ai/sdk/task/rubrics/__init__.py +2 -3
  103. synth_ai/sdk/task/rubrics/loaders.py +4 -4
  104. synth_ai/sdk/task/rubrics/strict.py +3 -4
  105. synth_ai/sdk/task/server.py +76 -16
  106. synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
  107. synth_ai/sdk/task/validators.py +34 -49
  108. synth_ai/sdk/training/__init__.py +7 -16
  109. synth_ai/sdk/tunnels/__init__.py +118 -0
  110. synth_ai/sdk/tunnels/cleanup.py +83 -0
  111. synth_ai/sdk/tunnels/ports.py +120 -0
  112. synth_ai/sdk/tunnels/tunneled_api.py +363 -0
  113. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
  114. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
  115. synth_ai/cli/commands/baseline/__init__.py +0 -12
  116. synth_ai/cli/commands/baseline/core.py +0 -636
  117. synth_ai/cli/commands/baseline/list.py +0 -94
  118. synth_ai/cli/commands/eval/errors.py +0 -81
  119. synth_ai/cli/commands/status/formatters.py +0 -164
  120. synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
  121. synth_ai/cli/commands/status/subcommands/usage.py +0 -203
  122. synth_ai/cli/commands/train/judge_validation.py +0 -305
  123. synth_ai/cli/usage.py +0 -159
  124. synth_ai/data/specs.py +0 -36
  125. synth_ai/sdk/api/research_agent/cli.py +0 -428
  126. synth_ai/sdk/api/research_agent/config.py +0 -357
  127. synth_ai/sdk/api/research_agent/job.py +0 -717
  128. synth_ai/sdk/baseline/__init__.py +0 -25
  129. synth_ai/sdk/baseline/config.py +0 -209
  130. synth_ai/sdk/baseline/discovery.py +0 -216
  131. synth_ai/sdk/baseline/execution.py +0 -154
  132. synth_ai/sdk/judging/__init__.py +0 -15
  133. synth_ai/sdk/judging/base.py +0 -24
  134. synth_ai/sdk/judging/client.py +0 -191
  135. synth_ai/sdk/judging/types.py +0 -42
  136. synth_ai/sdk/research_agent/__init__.py +0 -34
  137. synth_ai/sdk/research_agent/container_builder.py +0 -328
  138. synth_ai/sdk/research_agent/container_spec.py +0 -198
  139. synth_ai/sdk/research_agent/defaults.py +0 -34
  140. synth_ai/sdk/research_agent/results_collector.py +0 -69
  141. synth_ai/sdk/specs/__init__.py +0 -46
  142. synth_ai/sdk/specs/dataclasses.py +0 -149
  143. synth_ai/sdk/specs/loader.py +0 -144
  144. synth_ai/sdk/specs/serializer.py +0 -199
  145. synth_ai/sdk/specs/validation.py +0 -250
  146. synth_ai/sdk/tracing/__init__.py +0 -39
  147. synth_ai/sdk/usage/__init__.py +0 -37
  148. synth_ai/sdk/usage/client.py +0 -171
  149. synth_ai/sdk/usage/models.py +0 -261
  150. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
  151. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
  152. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
  153. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
@@ -1,191 +0,0 @@
1
- """Experimental Judge API client.
2
-
3
- This surface is experimental and subject to change without notice.
4
- Set environment variable `SYNTH_SILENCE_EXPERIMENTAL=1` to silence warnings.
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import os
10
- import warnings
11
- from typing import Any, Literal, TypedDict
12
-
13
- from synth_ai.core.http import AsyncHttpClient, HTTPError
14
- from synth_ai.core.tracing_v3.serialization import normalize_for_json
15
- from synth_ai.sdk.graphs import VerifierClient as GraphVerifierClient
16
-
17
- Provider = Literal["groq", "gemini"]
18
-
19
-
20
- class JudgeOptions(TypedDict, total=False):
21
- event: bool
22
- outcome: bool
23
- rubric_id: str
24
- rubric_overrides: dict[str, Any]
25
- provider: Provider
26
- model: str
27
- max_concurrency: int
28
- verifier_type: str
29
-
30
-
31
- class JudgeScoreResponse(TypedDict, total=False):
32
- status: str
33
- event_rewards: list[dict[str, Any]]
34
- outcome_reward: dict[str, Any]
35
- details: dict[str, Any]
36
-
37
-
38
- class JudgeClient:
39
- """Legacy client for LLM-based evaluation of task app traces.
40
-
41
- This client provides programmatic access to Synth AI's judge API, which uses
42
- LLMs to evaluate task execution traces and generate rewards. The judge can
43
- evaluate both event-level (step-by-step) and outcome-level (episode-level) rewards.
44
-
45
- .. warning::
46
- This API is experimental and subject to change without notice.
47
- Set `SYNTH_SILENCE_EXPERIMENTAL=1` to silence warnings.
48
-
49
- Example:
50
- >>> from synth_ai.sdk.judging import JudgeClient, JudgeOptions
51
- >>>
52
- >>> client = JudgeClient(
53
- ... base_url="https://api.usesynth.ai",
54
- ... api_key=os.environ["SYNTH_API_KEY"],
55
- ... )
56
- >>>
57
- >>> # Score a trace with outcome reward
58
- >>> result = await client.score(
59
- ... trace=my_trace_dict,
60
- ... policy_name="my_policy",
61
- ... task_app_id="heartdisease",
62
- ... options=JudgeOptions(
63
- ... outcome=True,
64
- ... rubric_id="accuracy",
65
- ... provider="groq",
66
- ... model="llama-3.1-8b-instant",
67
- ... ),
68
- ... )
69
- >>>
70
- >>> print(f"Outcome reward: {result['outcome_reward']}")
71
- """
72
-
73
- def __init__(self, base_url: str, api_key: str, *, timeout: float = 60.0) -> None:
74
- """Initialize the judge client.
75
-
76
- Args:
77
- base_url: Base URL for the Synth AI API
78
- api_key: API key for authentication
79
- timeout: Request timeout in seconds (default: 60.0)
80
- """
81
- _silence = (os.getenv("SYNTH_SILENCE_EXPERIMENTAL") or "").strip().lower()
82
- if _silence not in {"1", "true", "t", "yes", "y", "on"}:
83
- warnings.warn(
84
- "Legacy API: synth_ai.sdk.judging.JudgeClient is legacy. "
85
- "Use synth_ai.sdk.graphs.VerifierClient or GraphCompletionsClient instead.",
86
- UserWarning,
87
- stacklevel=2,
88
- )
89
- self._base = base_url.rstrip("/")
90
- self._key = api_key
91
- self._timeout = timeout
92
-
93
- async def score(
94
- self,
95
- *,
96
- trace: dict[str, Any] | Any,
97
- policy_name: str,
98
- task_app_id: str,
99
- options: JudgeOptions,
100
- rubric: dict[str, Any] | None = None,
101
- verifier_type: str | None = None,
102
- task_app_base_url: str | None = None,
103
- ) -> JudgeScoreResponse:
104
- """Score a task execution trace using LLM-based evaluation.
105
-
106
- This method sends a trace to the judge API, which evaluates it according
107
- to the provided rubric and returns event-level and/or outcome-level rewards.
108
-
109
- Args:
110
- trace: Task execution trace (SessionTrace dict or compatible object)
111
- policy_name: Name of the policy that generated this trace
112
- task_app_id: Identifier for the task app (e.g., "heartdisease")
113
- options: Judge configuration options:
114
- - event: Whether to generate event-level rewards (default: False)
115
- - outcome: Whether to generate outcome-level reward (default: False)
116
- - rubric_id: Rubric identifier to use for evaluation
117
- - rubric_overrides: Optional rubric modifications
118
- - provider: LLM provider ("groq" or "gemini")
119
- - model: Model identifier (e.g., "llama-3.1-8b-instant")
120
- - max_concurrency: Max concurrent judge calls (default: 1)
121
- rubric: Optional explicit rubric criteria (event/outcome lists)
122
- verifier_type: Optional zero-shot verifier graph ID (e.g., "zero_shot_verifier_single")
123
- task_app_base_url: Optional base URL for task app (for rubric fetching)
124
-
125
- Returns:
126
- JudgeScoreResponse with:
127
- - status: "ok" or error status
128
- - event_rewards: List of event-level reward dicts (if event=True)
129
- - outcome_reward: Outcome-level reward dict (if outcome=True)
130
- - details: Additional evaluation details
131
-
132
- Raises:
133
- ValueError: If validation fails or rubric is invalid
134
- PermissionError: If authentication fails
135
- FileNotFoundError: If task app or rubric not found
136
- Exception: For rate limiting or transient errors
137
- """
138
- trace_payload = normalize_for_json(trace)
139
- task_app_payload = {"id": task_app_id}
140
- if task_app_base_url:
141
- task_app_payload["base_url"] = task_app_base_url
142
-
143
- selected_verifier = verifier_type or (options or {}).get("verifier_type")
144
- if selected_verifier:
145
- graph_input = {
146
- "policy_name": policy_name,
147
- "task_app": task_app_payload,
148
- "session_trace": trace_payload,
149
- "trace": trace_payload,
150
- "options": options or {},
151
- }
152
- if rubric is not None:
153
- graph_input["rubric"] = normalize_for_json(rubric)
154
- body = {"job_id": selected_verifier, "input": graph_input}
155
- else:
156
- body = {
157
- "policy_name": policy_name,
158
- "task_app": task_app_payload,
159
- "trace": trace_payload,
160
- "options": options or {},
161
- }
162
- if rubric is not None:
163
- body["rubric"] = normalize_for_json(rubric)
164
- try:
165
- async with AsyncHttpClient(self._base, self._key, timeout=self._timeout) as http:
166
- if selected_verifier:
167
- js = await http.post_json("/api/graphs/completions", json=body)
168
- if isinstance(js, dict) and "output" in js:
169
- js = js["output"]
170
- else:
171
- js = await http.post_json("/api/judge/v1/score", json=body)
172
- if not isinstance(js, dict):
173
- raise ValueError("invalid_judge_response_shape")
174
- return js # type: ignore[return-value]
175
- except HTTPError as err: # map to friendlier exceptions
176
- status = int(getattr(err, "status", 0) or 0)
177
- if status in (400, 422):
178
- raise ValueError(f"judge_validation_error: {err.detail}") from err
179
- if status in (401, 403):
180
- raise PermissionError(f"judge_auth_error: {err.detail}") from err
181
- if status == 404:
182
- raise FileNotFoundError(f"judge_route_not_found: {err.detail}") from err
183
- if status == 429:
184
- raise Exception("judge_rate_limited") from err # replace with RetryLater in future
185
- if status >= 500:
186
- raise Exception("judge_transient_error") from err # replace with TransientError in future
187
- raise
188
-
189
-
190
- class VerifierClient(GraphVerifierClient):
191
- """Deprecated alias for graph-based VerifierClient."""
@@ -1,42 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Literal, TypedDict
4
-
5
- Track = Literal["process", "reasoning", "progress", "outcome"]
6
-
7
-
8
- class Judgement(TypedDict, total=False):
9
- key: str
10
- title: str
11
- description: str
12
- score: float
13
- reason: str
14
- confidence: float
15
- scale: Literal["binary", "bounded", "count", "custom"]
16
- source: dict
17
-
18
-
19
- class RewardJudgement(TypedDict, total=False):
20
- judgement: Judgement
21
- scope: Literal["step", "event", "outcome"]
22
- turn: int | None
23
- episode_id: str | None
24
- reward_value: float | None
25
- links: dict
26
-
27
-
28
- class TrackAggregate(TypedDict, total=False):
29
- mean: float
30
- median: float
31
- std: float
32
- n: int
33
-
34
-
35
- class RewardMetadata(TypedDict, total=False):
36
- per_window: list[RewardJudgement]
37
- aggregates: dict[Track, TrackAggregate]
38
- overall: dict[str, float] # {"final_outcome_score": float}
39
- rubric: dict # {"ids": {...}, "hash": "..."}
40
- model_info: dict # {"model": "...", ...}
41
-
42
-
@@ -1,34 +0,0 @@
1
- from synth_ai.sdk.research_agent.container_builder import (
2
- ContainerBackend,
3
- DockerBackend,
4
- ModalBackend,
5
- get_backend,
6
- )
7
- from synth_ai.sdk.research_agent.container_spec import ContainerSpec
8
- from synth_ai.sdk.research_agent.defaults import (
9
- DEFAULT_BACKEND,
10
- DEFAULT_BASE_IMAGE,
11
- DEFAULT_INSTRUCTIONS,
12
- DEFAULT_PACKAGES,
13
- DEFAULT_PYTHON_VERSION,
14
- DEFAULT_REASONING_EFFORT,
15
- DEFAULT_RESULT_PATTERNS,
16
- )
17
- from synth_ai.sdk.research_agent.results_collector import ResultsCollector
18
-
19
- __all__ = [
20
- "ContainerBackend",
21
- "ContainerSpec",
22
- "DockerBackend",
23
- "ModalBackend",
24
- "ResultsCollector",
25
- "get_backend",
26
- "DEFAULT_BACKEND",
27
- "DEFAULT_BASE_IMAGE",
28
- "DEFAULT_INSTRUCTIONS",
29
- "DEFAULT_PACKAGES",
30
- "DEFAULT_PYTHON_VERSION",
31
- "DEFAULT_REASONING_EFFORT",
32
- "DEFAULT_RESULT_PATTERNS",
33
- ]
34
-
@@ -1,328 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import asyncio
4
- import base64
5
- import contextlib
6
- import fnmatch
7
- import io
8
- import tarfile
9
- import tempfile
10
- import time
11
- import uuid
12
- from abc import ABC, abstractmethod
13
- from pathlib import Path
14
- from typing import Dict, Iterable, Tuple
15
-
16
- from synth_ai.sdk.research_agent.container_spec import ContainerSpec
17
- from synth_ai.sdk.research_agent.defaults import DEFAULT_BACKEND
18
-
19
-
20
- class ContainerBackend(ABC):
21
- """Abstract base for container execution backends."""
22
-
23
- @abstractmethod
24
- async def provision(self, spec: ContainerSpec) -> str:
25
- """Provision a new container and return its id/handle."""
26
-
27
- @abstractmethod
28
- async def execute(
29
- self,
30
- container_id: str,
31
- command: str,
32
- *,
33
- env: Dict[str, str] | None = None,
34
- workdir: Path | None = None,
35
- ) -> Dict[str, str | int]:
36
- """Execute a command in the container."""
37
-
38
- @abstractmethod
39
- async def collect_artifacts(self, container_id: str, patterns: Iterable[str]) -> Dict[str, bytes]:
40
- """Pull artifacts that match any of the glob patterns."""
41
-
42
- @abstractmethod
43
- async def destroy(self, container_id: str) -> None:
44
- """Tear down container resources."""
45
-
46
-
47
- class DockerBackend(ContainerBackend):
48
- """Docker implementation using docker-py."""
49
-
50
- def __init__(self, *, client=None):
51
- self._client = client
52
- self._containers: Dict[str, Tuple[object, ContainerSpec]] = {}
53
-
54
- def _ensure_client(self):
55
- if self._client is None:
56
- try:
57
- import docker # type: ignore
58
- except ImportError as exc:
59
- raise RuntimeError("docker SDK is not installed. Add docker>=7.0.0 to dependencies.") from exc
60
- self._client = docker.from_env()
61
- return self._client
62
-
63
- async def provision(self, spec: ContainerSpec) -> str:
64
- spec.validate()
65
- client = self._ensure_client()
66
- context_bytes = spec.build_context()
67
- image_tag = f"research-agent:{int(time.time())}"
68
-
69
- def _build():
70
- return client.images.build(
71
- fileobj=io.BytesIO(context_bytes),
72
- custom_context=True,
73
- rm=True,
74
- nocache=True,
75
- tag=image_tag,
76
- buildargs=spec.build_args,
77
- )
78
-
79
- loop = asyncio.get_running_loop()
80
- image, _ = await loop.run_in_executor(None, _build)
81
-
82
- container = client.containers.create(
83
- image=image.id,
84
- command="sleep infinity",
85
- environment={**spec.env_vars, **spec.secrets},
86
- tty=True,
87
- detach=True,
88
- working_dir=str(spec.workdir),
89
- )
90
- container.start()
91
- self._containers[container.id] = (container, spec)
92
- return container.id
93
-
94
- async def execute(
95
- self,
96
- container_id: str,
97
- command: str,
98
- *,
99
- env: Dict[str, str] | None = None,
100
- workdir: Path | None = None,
101
- ) -> Dict[str, str | int]:
102
- container, spec = self._containers[container_id]
103
- if env is None:
104
- env = {}
105
- exec_env = {**spec.env_vars, **spec.secrets, **env}
106
- workdir_str = str(workdir or spec.workdir)
107
-
108
- def _run():
109
- result = container.exec_run( # type: ignore[attr-defined]
110
- cmd=["bash", "-lc", command],
111
- environment=exec_env,
112
- workdir=workdir_str,
113
- demux=True,
114
- )
115
- stdout, stderr = result.output
116
- return {
117
- "exit_code": result.exit_code,
118
- "stdout": (stdout or b"").decode(),
119
- "stderr": (stderr or b"").decode(),
120
- }
121
-
122
- loop = asyncio.get_running_loop()
123
- return await loop.run_in_executor(None, _run)
124
-
125
- async def collect_artifacts(self, container_id: str, patterns: Iterable[str]) -> Dict[str, bytes]:
126
- container, spec = self._containers[container_id]
127
- loop = asyncio.get_running_loop()
128
-
129
- def _pull():
130
- try:
131
- stream, _ = container.get_archive(str(spec.artifacts_dir)) # type: ignore[attr-defined]
132
- except Exception:
133
- return {}
134
- tar_bytes = b"".join(stream)
135
- collected: Dict[str, bytes] = {}
136
- with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r:*") as tar:
137
- for member in tar.getmembers():
138
- if not member.isfile():
139
- continue
140
- relative_name = str(Path(member.name).name)
141
- if not any(fnmatch.fnmatch(relative_name, pat) for pat in patterns):
142
- continue
143
- file_obj = tar.extractfile(member)
144
- if file_obj:
145
- collected[relative_name] = file_obj.read()
146
- return collected
147
-
148
- return await loop.run_in_executor(None, _pull)
149
-
150
- async def destroy(self, container_id: str) -> None:
151
- container, _ = self._containers.pop(container_id, (None, None))
152
- if container is None:
153
- return
154
-
155
- def _stop():
156
- with contextlib.suppress(Exception):
157
- container.kill() # type: ignore[attr-defined]
158
- with contextlib.suppress(Exception):
159
- container.remove(force=True) # type: ignore[attr-defined]
160
-
161
- loop = asyncio.get_running_loop()
162
- await loop.run_in_executor(None, _stop)
163
-
164
-
165
- class ModalBackend(ContainerBackend):
166
- """Modal implementation using modal SDK. Returns artifacts inline from execute()."""
167
-
168
- def __init__(self):
169
- self._runs: Dict[str, Dict[str, object]] = {}
170
-
171
- def _write_build_context(self, spec: ContainerSpec) -> tempfile.TemporaryDirectory:
172
- """Materialize Dockerfile + overlay files to a temp dir for Modal build."""
173
- temp_dir = tempfile.TemporaryDirectory()
174
- ctx = Path(temp_dir.name)
175
- (ctx / "Dockerfile").write_text(spec.to_dockerfile())
176
-
177
- overlay_root = ctx / "overlay_files"
178
- for rel_path, content in spec.rendered_overlay_files().items():
179
- target = overlay_root / rel_path
180
- target.parent.mkdir(parents=True, exist_ok=True)
181
- target.write_bytes(content)
182
-
183
- for rel_path, content in spec.files.items():
184
- if not str(rel_path).startswith("/"):
185
- continue
186
- data = content.encode() if isinstance(content, str) else content
187
- target = ctx / str(rel_path).lstrip("/")
188
- target.parent.mkdir(parents=True, exist_ok=True)
189
- target.write_bytes(data)
190
-
191
- return temp_dir
192
-
193
- async def provision(self, spec: ContainerSpec) -> str:
194
- spec.validate()
195
- try:
196
- import modal # type: ignore
197
- except ImportError as exc: # pragma: no cover - runtime import guard
198
- raise RuntimeError("modal SDK is not installed. Add modal>=1.1.1 to dependencies.") from exc
199
-
200
- ctx_dir = self._write_build_context(spec)
201
- loop = asyncio.get_running_loop()
202
-
203
- def _build_image():
204
- return modal.Image.from_dockerfile(
205
- path=ctx_dir.name,
206
- build_args=spec.build_args,
207
- force_build=True,
208
- )
209
-
210
- image = await loop.run_in_executor(None, _build_image)
211
-
212
- # Combine env_vars and secrets into a Modal Secret
213
- # Modal function decorator doesn't accept 'env' parameter directly
214
- # Environment variables must be passed via secrets
215
- combined_env: dict[str, str | None] = {**spec.env_vars, **spec.secrets}
216
- secret_obj = None
217
- if combined_env:
218
- secret_obj = modal.Secret.from_dict(combined_env)
219
-
220
- app = modal.App(f"oneshot-research-{int(time.time())}")
221
-
222
- workdir_str = str(spec.workdir)
223
-
224
- @app.function(
225
- image=image,
226
- timeout=60 * 60,
227
- secrets=[secret_obj] if secret_obj else [],
228
- )
229
- def run_task(command: str, patterns: list[str], artifacts_dir: str = "/app/artifacts") -> Dict:
230
- """Execute the agent and pull artifacts matching patterns."""
231
- import glob
232
- import os
233
- import subprocess
234
-
235
- result = subprocess.run(
236
- ["bash", "-lc", command],
237
- capture_output=True,
238
- text=True,
239
- cwd=workdir_str,
240
- )
241
-
242
- artifacts: Dict[str, str] = {}
243
- for pat in patterns:
244
- for path in glob.glob(os.path.join(artifacts_dir, pat)):
245
- if not os.path.isfile(path):
246
- continue
247
- name = os.path.basename(path)
248
- with open(path, "rb") as f:
249
- artifacts[name] = base64.b64encode(f.read()).decode()
250
-
251
- return {
252
- "exit_code": result.returncode,
253
- "stdout": result.stdout,
254
- "stderr": result.stderr,
255
- "artifacts": artifacts,
256
- }
257
-
258
- container_id = str(uuid.uuid4())
259
- self._runs[container_id] = {
260
- "app": app,
261
- "function": run_task,
262
- "result": None,
263
- "ctx_dir": ctx_dir,
264
- "patterns": tuple(spec.result_matchers()),
265
- }
266
- return container_id
267
-
268
- async def execute(
269
- self,
270
- container_id: str,
271
- command: str,
272
- *,
273
- env: Dict[str, str] | None = None,
274
- workdir: Path | None = None,
275
- ) -> Dict[str, str | int]:
276
- run_info = self._runs.get(container_id)
277
- if not run_info:
278
- raise ValueError(f"Unknown container_id: {container_id}")
279
- app = run_info["app"]
280
- run_fn = run_info["function"]
281
- patterns = list(run_info["patterns"])
282
-
283
- loop = asyncio.get_running_loop()
284
-
285
- def _call():
286
- with app.run():
287
- return run_fn.call(command, patterns)
288
-
289
- result = await loop.run_in_executor(None, _call)
290
- run_info["result"] = result
291
- return {
292
- "exit_code": result.get("exit_code", -1),
293
- "stdout": result.get("stdout", ""),
294
- "stderr": result.get("stderr", ""),
295
- }
296
-
297
- async def collect_artifacts(self, container_id: str, patterns: Iterable[str]) -> Dict[str, bytes]:
298
- run_info = self._runs.get(container_id)
299
- if not run_info:
300
- return {}
301
- result = run_info.get("result") or {}
302
- artifacts: Dict[str, bytes] = {}
303
- encoded = result.get("artifacts") or {} # type: ignore[misc]
304
- for name, b64 in encoded.items():
305
- try:
306
- artifacts[name] = base64.b64decode(b64)
307
- except Exception:
308
- continue
309
- return artifacts
310
-
311
- async def destroy(self, container_id: str) -> None:
312
- info = self._runs.pop(container_id, None)
313
- if not info:
314
- return
315
- ctx_dir = info.get("ctx_dir")
316
- if ctx_dir and hasattr(ctx_dir, "cleanup"):
317
- with contextlib.suppress(Exception):
318
- ctx_dir.cleanup() # type: ignore[call-arg]
319
-
320
-
321
- def get_backend(name: str = DEFAULT_BACKEND) -> ContainerBackend:
322
- """Resolve backend by name."""
323
- normalized = (name or DEFAULT_BACKEND).lower()
324
- if normalized == "docker":
325
- return DockerBackend()
326
- if normalized == "modal":
327
- return ModalBackend()
328
- raise ValueError(f"Unsupported container backend: {name}")