agent-runtime-kit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_runtime_kit/__init__.py +72 -0
- agent_runtime_kit/_errors.py +34 -0
- agent_runtime_kit/_runtime.py +139 -0
- agent_runtime_kit/_types.py +251 -0
- agent_runtime_kit/adapters/__init__.py +26 -0
- agent_runtime_kit/adapters/_common.py +123 -0
- agent_runtime_kit/adapters/antigravity.py +379 -0
- agent_runtime_kit/adapters/claude.py +302 -0
- agent_runtime_kit/adapters/codex.py +298 -0
- agent_runtime_kit/adapters/diagnostics.py +18 -0
- agent_runtime_kit/events.py +224 -0
- agent_runtime_kit/py.typed +1 -0
- agent_runtime_kit/registry.py +83 -0
- agent_runtime_kit/testing/__init__.py +15 -0
- agent_runtime_kit/testing/fakes.py +164 -0
- agent_runtime_kit-0.1.0.dist-info/METADATA +118 -0
- agent_runtime_kit-0.1.0.dist-info/RECORD +19 -0
- agent_runtime_kit-0.1.0.dist-info/WHEEL +4 -0
- agent_runtime_kit-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,379 @@
|
|
|
1
|
+
"""Google Antigravity SDK runtime adapter."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from collections.abc import Mapping
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from tempfile import gettempdir
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from agent_runtime_kit._errors import UnsupportedTaskInputError
|
|
12
|
+
from agent_runtime_kit._types import (
|
|
13
|
+
AgentCapabilities,
|
|
14
|
+
AgentResult,
|
|
15
|
+
AgentRuntimeKind,
|
|
16
|
+
AgentTask,
|
|
17
|
+
AvailabilityReason,
|
|
18
|
+
FilesystemAccess,
|
|
19
|
+
PermissionMode,
|
|
20
|
+
RuntimeAvailability,
|
|
21
|
+
ToolCallAudit,
|
|
22
|
+
Usage,
|
|
23
|
+
)
|
|
24
|
+
from agent_runtime_kit.adapters._common import (
|
|
25
|
+
ensure_supported_model,
|
|
26
|
+
metadata_str,
|
|
27
|
+
output_schema_from,
|
|
28
|
+
package_availability,
|
|
29
|
+
parse_json_output,
|
|
30
|
+
)
|
|
31
|
+
from agent_runtime_kit.events import (
|
|
32
|
+
output_delta_event,
|
|
33
|
+
safe_emit,
|
|
34
|
+
task_completed_event,
|
|
35
|
+
task_failed_event,
|
|
36
|
+
task_started_event,
|
|
37
|
+
tool_completed_event,
|
|
38
|
+
tool_requested_event,
|
|
39
|
+
vendor_turn_event,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class AntigravityAgentRuntime:
    """Run tasks through Google's ``google-antigravity`` SDK.

    The SDK entry points (agent class, config class, ``types`` module and
    ``policy`` module) may be injected through the constructor. ``_load_sdk``
    uses the injected objects only when all four are supplied; otherwise it
    imports ``google.antigravity`` lazily at run time.
    """

    kind = AgentRuntimeKind.ANTIGRAVITY_AGENT_SDK
    # Static feature flags for this adapter. ``cancel`` below is a no-op,
    # which is why cancellation is advertised as unsupported.
    capabilities = AgentCapabilities(
        mcp_support=True,
        working_directory=True,
        session_resume=True,
        structured_output=True,
        streaming=True,
        tool_audit=True,
        cancellation=False,
    )

    def __init__(
        self,
        *,
        default_model: str = "gemini-3.5-flash",
        supported_models: tuple[str, ...] | None = None,
        api_key: str | None = None,
        agent_cls: Any | None = None,
        config_cls: Any | None = None,
        types_module: Any | None = None,
        policy_module: Any | None = None,
    ) -> None:
        """Store adapter configuration; no SDK import happens here.

        Args:
            default_model: Model used when ``task.metadata`` has no ``"model"`` entry.
            supported_models: Optional allow-list passed to ``ensure_supported_model``.
            api_key: Explicit API key. When unset, ``GEMINI_API_KEY`` and then
                ``GOOGLE_API_KEY`` are read from the environment.
            agent_cls: Optional replacement for ``google.antigravity.agent.Agent``.
            config_cls: Optional replacement for ``LocalAgentConfig``.
            types_module: Optional replacement for ``google.antigravity.types``.
            policy_module: Optional replacement for ``google.antigravity.hooks.policy``.
        """
        self._default_model = default_model
        self._supported_models = supported_models
        self._api_key = api_key
        self._agent_cls = agent_cls
        self._config_cls = config_cls
        self._types = types_module
        self._policy = policy_module

    def availability(self) -> RuntimeAvailability:
        """Report Antigravity package and API-key availability."""

        # An injected agent class short-circuits both the package probe and the
        # credential check below.
        if self._agent_cls is not None:
            return RuntimeAvailability.ok(self.kind, package="google-antigravity")
        package = package_availability(
            self.kind,
            module_name="google.antigravity",
            package_name="google-antigravity",
        )
        if not package.available:
            return package
        # The package is importable but running still needs a key (see ``run``).
        if self._api_key_value() is None:
            return RuntimeAvailability.unavailable(
                self.kind,
                reason=AvailabilityReason.MISSING_CREDENTIALS,
                message="Set GEMINI_API_KEY or GOOGLE_API_KEY to use Antigravity.",
                package="google-antigravity",
                metadata=package.metadata,
            )
        return package

    async def run(self, task: AgentTask) -> AgentResult:
        """Execute one task with Antigravity.

        Emits ``task_started`` first. Any exception raised while resolving the
        model, loading the SDK, building the config or invoking the agent is
        reported as ``task_failed`` and then re-raised. A result that carries
        ``error`` (e.g. missing structured output) is reported as ``task_failed``
        but still returned; otherwise ``task_completed`` is emitted.

        Raises:
            RuntimeError: when no API key is configured or the SDK is missing.
            UnsupportedTaskInputError: when an MCP server declares ``env``.
        """

        await safe_emit(task, task_started_event(task, self.kind))
        try:
            model = self._model(task)
            ensure_supported_model(
                kind=self.kind,
                model=model,
                supported_models=self._supported_models,
            )
            sdk = self._load_sdk()
            api_key = self._api_key_value()
            if api_key is None:
                raise RuntimeError("Antigravity requires GEMINI_API_KEY or GOOGLE_API_KEY")
            config = self._build_config(task, model=model, api_key=api_key, sdk=sdk)
            result = await self._invoke(task, config=config, sdk=sdk, model=model)
        except Exception as exc:
            await safe_emit(task, task_failed_event(task, self.kind, error=str(exc)))
            raise

        if result.error:
            await safe_emit(task, task_failed_event(task, self.kind, error=result.error))
        else:
            await safe_emit(task, task_completed_event(task, self.kind, result))
        return result

    async def cancel(self, task_id: str) -> None:
        """Antigravity cancellation is not exposed through this portable adapter yet."""

        del task_id

    def _load_sdk(self) -> _AntigravitySDK:
        """Return the SDK bundle used for one run.

        Injected objects are used only when all four were supplied. Otherwise
        the real ``google.antigravity`` modules are imported here, lazily, so the
        adapter can be constructed without the optional dependency installed.

        Raises:
            RuntimeError: if ``google-antigravity`` cannot be imported.
        """
        if (
            self._agent_cls is not None
            and self._config_cls is not None
            and self._types is not None
            and self._policy is not None
        ):
            return _AntigravitySDK(self._agent_cls, self._config_cls, self._types, self._policy)
        try:
            from google.antigravity import types  # type: ignore[import-not-found]
            from google.antigravity.agent import Agent  # type: ignore[import-not-found]
            from google.antigravity.connections.local.local_connection_config import (  # type: ignore[import-not-found]
                LocalAgentConfig,
            )
            from google.antigravity.hooks import policy  # type: ignore[import-not-found]
        except ImportError as exc:
            raise RuntimeError(
                "google-antigravity is not installed. Install agent-runtime-kit[antigravity]."
            ) from exc
        return _AntigravitySDK(Agent, LocalAgentConfig, types, policy)

    def _build_config(
        self,
        task: AgentTask,
        *,
        model: str,
        api_key: str,
        sdk: _AntigravitySDK,
    ) -> Any:
        """Build the ``sdk.config_cls`` instance for one task.

        Session and app-data directories are created under the system temp dir
        via ``_runtime_dir``. MCP servers are forwarded as stdio servers with
        only ``command`` and ``args``.

        Raises:
            UnsupportedTaskInputError: if any MCP server sets ``env``, because the
                ``McpStdioServer`` construction below does not pass it through.
        """
        for server in task.mcp_servers:
            if server.env:
                raise UnsupportedTaskInputError(
                    self.kind,
                    "mcp_servers.env",
                    "Antigravity MCP stdio server config does not support env",
                )
        capabilities, policies = _capability_policy(task, sdk)
        schema = output_schema_from(task.output_schema, task.metadata)
        return sdk.config_cls(
            model=model,
            api_key=api_key,
            system_instructions=task.system,
            capabilities=capabilities,
            policies=policies,
            workspaces=_workspaces(task),
            conversation_id=_conversation_id(task),
            save_dir=str(_runtime_dir("antigravity-sessions")),
            app_data_dir=str(_runtime_dir("antigravity-app-data")),
            # Copy into a plain dict in case the schema is a read-only mapping.
            response_schema=dict(schema) if schema is not None else None,
            mcp_servers=[
                sdk.types.McpStdioServer(command=server.command, args=list(server.args))
                for server in task.mcp_servers
            ],
        )

    async def _invoke(
        self,
        task: AgentTask,
        *,
        config: Any,
        sdk: _AntigravitySDK,
        model: str,
    ) -> AgentResult:
        """Run one chat turn and assemble the ``AgentResult``.

        ``task.goal`` is sent once and every response chunk is routed through
        ``_consume_chunk``, which fills ``text_parts`` and ``tool_calls``.
        If an output schema is in effect and the SDK returned no structured
        output, the collected text is parsed as JSON. If that also yields
        ``None``, a failed result (``finish_reason="failed"``) is returned
        instead of raising.
        """
        text_parts: list[str] = []
        tool_calls: list[ToolCallAudit] = []
        usage_metadata: Any | None = None
        structured_output: Any | None = None
        session_id: str | None = None

        async with sdk.agent_cls(config) as agent:
            response = await agent.chat(task.goal)
            async for chunk in response.chunks:
                await self._consume_chunk(
                    task,
                    chunk=chunk,
                    sdk=sdk,
                    text_parts=text_parts,
                    tool_calls=tool_calls,
                )
            # ``structured_output()`` may return a value or an awaitable.
            structured_output = await _maybe_await(response.structured_output())
            usage_metadata = getattr(response, "usage_metadata", None)
            # The agent's conversation id becomes the resumable session id.
            session_id = _optional_str(getattr(agent, "conversation_id", None))

        output = "".join(text_parts).strip()
        schema = output_schema_from(task.output_schema, task.metadata)
        # Fallback: try to recover structured output from the plain-text answer.
        if structured_output is None and schema is not None:
            structured_output = parse_json_output(output)
        if schema is not None and structured_output is None:
            return AgentResult(
                output=output,
                finish_reason="failed",
                error="Antigravity SDK returned no structured output for output_schema",
                usage=_usage_from(usage_metadata),
                tool_calls=tuple(tool_calls),
                session_id=session_id,
                metadata={"model": model, "sdk": "google_antigravity"},
            )
        return AgentResult(
            output=output,
            parsed_output=structured_output,
            usage=_usage_from(usage_metadata),
            tool_calls=tuple(tool_calls),
            session_id=session_id,
            rounds=1,
            metadata={"model": model, "sdk": "google_antigravity"},
        )

    async def _consume_chunk(
        self,
        task: AgentTask,
        *,
        chunk: Any,
        sdk: _AntigravitySDK,
        text_parts: list[str],
        tool_calls: list[ToolCallAudit],
    ) -> None:
        """Dispatch one streamed chunk by its SDK type.

        * ``Text``: appended to ``text_parts`` and emitted as an output delta.
        * ``Thought``: only its length is emitted as a vendor-turn event; the
          thought text itself is not forwarded.
        * ``ToolCall``: emitted as a tool-requested event.
        * ``ToolResult``: recorded in ``tool_calls`` with ``status="ok"`` and a
          result preview capped at 256 characters, then emitted as tool-completed.
        * anything else: emitted as a vendor-turn event carrying the class name.

        Mutates ``text_parts`` and ``tool_calls`` in place.
        """
        if isinstance(chunk, sdk.types.Text):
            text_parts.append(str(chunk.text))
            await safe_emit(task, output_delta_event(task, self.kind, text=str(chunk.text)))
            return
        if isinstance(chunk, sdk.types.Thought):
            await safe_emit(
                task,
                vendor_turn_event(
                    task,
                    self.kind,
                    payload={"chunk_type": "Thought", "delta_length": len(str(chunk.text))},
                ),
            )
            return
        if isinstance(chunk, sdk.types.ToolCall):
            await safe_emit(
                task,
                tool_requested_event(
                    task,
                    self.kind,
                    tool_name=_tool_name(getattr(chunk, "name", "tool")),
                    arguments=_tool_arguments(chunk),
                ),
            )
            return
        if isinstance(chunk, sdk.types.ToolResult):
            audit = ToolCallAudit(
                tool_name=_tool_name(getattr(chunk, "name", "tool")),
                arguments=_tool_arguments(chunk),
                result_preview=str(getattr(chunk, "result", ""))[:256],
                status="ok",
            )
            tool_calls.append(audit)
            await safe_emit(task, tool_completed_event(task, self.kind, audit))
            return
        await safe_emit(
            task,
            vendor_turn_event(task, self.kind, payload={"chunk_type": type(chunk).__name__}),
        )

    def _api_key_value(self) -> str | None:
        """Return the API key: explicit argument, then ``GEMINI_API_KEY``, then ``GOOGLE_API_KEY``.

        Empty strings are skipped because the chain uses ``or``.
        """
        return self._api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")

    def _model(self, task: AgentTask) -> str:
        """Return ``task.metadata["model"]`` when set (via ``metadata_str``), else the default model."""
        return metadata_str(task.metadata, "model") or self._default_model
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class _AntigravitySDK:
|
|
294
|
+
def __init__(self, agent_cls: Any, config_cls: Any, types: Any, policy: Any) -> None:
|
|
295
|
+
self.agent_cls = agent_cls
|
|
296
|
+
self.config_cls = config_cls
|
|
297
|
+
self.types = types
|
|
298
|
+
self.policy = policy
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def _capability_policy(task: AgentTask, sdk: _AntigravitySDK) -> tuple[Any, list[Any]]:
    """Map ``task.permissions`` onto Antigravity capabilities and hook policies.

    Tool selection:
      * no explicit ``allowed_tools`` (empty or unset): a built-in preset is used,
        ``read_only()`` for read-only filesystem access, ``nondestructive()`` in
        cautious mode and ``all_tools()`` otherwise;
      * an explicit ``allowed_tools`` sequence is passed through as a list.

    Sub-agents are enabled only in permissive mode, and only when the SDK's
    ``BuiltinTools.START_SUBAGENT`` member is among the selected tools.

    Returns:
        ``(capabilities, policies)``. ``policies`` is empty in strict mode and
        ``[sdk.policy.allow_all()]`` in every other mode.
    """
    permissions = task.permissions
    builtin = sdk.types.BuiltinTools
    # Truthiness instead of ``== ()``. An empty list (or ``None``) must also select
    # a preset. ``[] == ()`` is False, which previously enabled *no* tools for an
    # empty list, and ``list(None)`` crashed.
    if not permissions.allowed_tools:
        if permissions.filesystem is FilesystemAccess.READ_ONLY:
            tools = builtin.read_only()
        elif permissions.mode is PermissionMode.CAUTIOUS:
            tools = builtin.nondestructive()
        else:
            tools = builtin.all_tools()
    else:
        tools = list(permissions.allowed_tools)
    enable_subagents = (
        permissions.mode is PermissionMode.PERMISSIVE
        and getattr(builtin, "START_SUBAGENT", None) in tools
    )
    capabilities = sdk.types.CapabilitiesConfig(
        enabled_tools=tools,
        enable_subagents=enable_subagents,
    )
    policies = [] if permissions.mode is PermissionMode.STRICT else [sdk.policy.allow_all()]
    return capabilities, policies
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def _workspaces(task: AgentTask) -> list[str]:
|
|
325
|
+
if task.working_directory is None:
|
|
326
|
+
return []
|
|
327
|
+
return [str(task.working_directory)]
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _conversation_id(task: AgentTask) -> str | None:
|
|
331
|
+
if task.resume_from is not None:
|
|
332
|
+
return task.resume_from.session_id
|
|
333
|
+
return task.session_id
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def _runtime_dir(name: str) -> Path:
|
|
337
|
+
path = Path(gettempdir()) / "agent-runtime-kit" / name
|
|
338
|
+
path.mkdir(mode=0o700, parents=True, exist_ok=True)
|
|
339
|
+
return path
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def _usage_from(value: Any) -> Usage:
    """Convert Gemini-style ``usage_metadata`` into a portable ``Usage``.

    Missing or non-numeric counters count as 0 (via ``_optional_int``). Cached
    prompt tokens are reported separately and subtracted from the input count
    (floored at 0). Thinking tokens are folded into the output count.
    ``total_tokens`` is taken as reported.
    """
    counts = {
        attr: _optional_int(getattr(value, attr, None))
        for attr in (
            "prompt_token_count",
            "candidates_token_count",
            "thoughts_token_count",
            "cached_content_token_count",
            "total_token_count",
        )
    }
    cached = counts["cached_content_token_count"]
    return Usage(
        input_tokens=max(counts["prompt_token_count"] - cached, 0),
        output_tokens=counts["candidates_token_count"] + counts["thoughts_token_count"],
        cache_read_tokens=cached,
        total_tokens=counts["total_token_count"],
    )
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def _tool_arguments(value: Any) -> Mapping[str, Any]:
|
|
357
|
+
args = getattr(value, "args", {})
|
|
358
|
+
return dict(args) if isinstance(args, Mapping) else {}
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def _tool_name(value: Any) -> str:
|
|
362
|
+
return str(getattr(value, "value", value) or "tool")
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
async def _maybe_await(value: Any) -> Any:
|
|
366
|
+
if hasattr(value, "__await__"):
|
|
367
|
+
return await value
|
|
368
|
+
return value
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _optional_str(value: Any) -> str | None:
|
|
372
|
+
return str(value) if value else None
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def _optional_int(value: Any) -> int:
|
|
376
|
+
try:
|
|
377
|
+
return int(value or 0)
|
|
378
|
+
except (TypeError, ValueError):
|
|
379
|
+
return 0
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
"""Claude Agent SDK runtime adapter."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import inspect
|
|
6
|
+
from collections.abc import AsyncIterator, Iterable, Mapping
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from agent_runtime_kit._types import (
|
|
10
|
+
AgentCapabilities,
|
|
11
|
+
AgentResult,
|
|
12
|
+
AgentRuntimeKind,
|
|
13
|
+
AgentTask,
|
|
14
|
+
PermissionMode,
|
|
15
|
+
RuntimeAvailability,
|
|
16
|
+
ToolCallAudit,
|
|
17
|
+
Usage,
|
|
18
|
+
)
|
|
19
|
+
from agent_runtime_kit.adapters._common import (
|
|
20
|
+
ensure_supported_model,
|
|
21
|
+
filter_supported_kwargs,
|
|
22
|
+
metadata_str,
|
|
23
|
+
output_schema_from,
|
|
24
|
+
package_availability,
|
|
25
|
+
parse_json_output,
|
|
26
|
+
)
|
|
27
|
+
from agent_runtime_kit.events import (
|
|
28
|
+
output_delta_event,
|
|
29
|
+
safe_emit,
|
|
30
|
+
task_completed_event,
|
|
31
|
+
task_failed_event,
|
|
32
|
+
task_started_event,
|
|
33
|
+
tool_completed_event,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ClaudeAgentRuntime:
    """Run tasks through ``claude-agent-sdk`` using the shared runtime API.

    The SDK entry points can be injected: ``query_func`` replaces
    ``claude_agent_sdk.query`` and ``options_cls`` replaces ``ClaudeAgentOptions``.
    The real SDK is imported lazily, and only when at least one of them is missing.
    """

    kind = AgentRuntimeKind.CLAUDE_AGENT_SDK
    # ``cancel`` below is a no-op, so cancellation is advertised as unsupported.
    capabilities = AgentCapabilities(
        mcp_support=True,
        working_directory=True,
        session_resume=True,
        structured_output=True,
        streaming=True,
        tool_audit=True,
        cancellation=False,
    )

    def __init__(
        self,
        *,
        default_model: str = "claude-sonnet-4-6",
        supported_models: tuple[str, ...] | None = None,
        query_func: Any | None = None,
        options_cls: Any | None = None,
    ) -> None:
        """Store adapter configuration; nothing is imported or executed here.

        Args:
            default_model: Model used when ``task.metadata`` has no ``"model"`` entry.
            supported_models: Optional allow-list passed to ``ensure_supported_model``.
            query_func: Optional replacement for ``claude_agent_sdk.query``.
            options_cls: Optional replacement for ``ClaudeAgentOptions``.
        """
        self._default_model = default_model
        self._supported_models = supported_models
        self._query_func = query_func
        self._options_cls = options_cls

    def availability(self) -> RuntimeAvailability:
        """Report Claude Agent SDK package availability.

        An injected ``query_func`` is reported as available without probing the package.
        """

        if self._query_func is not None:
            return RuntimeAvailability.ok(self.kind, package="claude-agent-sdk")
        return package_availability(
            self.kind,
            module_name="claude_agent_sdk",
            package_name="claude-agent-sdk",
        )

    async def run(self, task: AgentTask) -> AgentResult:
        """Execute one task with Claude Agent SDK.

        Emits ``task_started`` first. Any exception raised while resolving the
        model, loading the SDK, building options, querying or translating the
        messages is reported as ``task_failed`` and re-raised.

        On success, events are emitted in this order:
        1. the collected tool calls are replayed as ``tool_completed`` events;
        2. the full output is emitted as a single ``output_delta``;
        3. ``task_failed`` if the result carries an error, else ``task_completed``.
        """

        await safe_emit(task, task_started_event(task, self.kind))
        try:
            # Resolved inside the ``try`` (as the Antigravity adapter does) so a
            # failure here is still reported as ``task_failed``.
            model = self._model(task)
            ensure_supported_model(
                kind=self.kind,
                model=model,
                supported_models=self._supported_models,
            )
            query_func, options_cls = self._load_sdk()
            options = self._build_options(task, model, options_cls)
            messages = await _collect_messages(query_func(prompt=task.goal, options=options))
            result = _translate_messages(task, messages, model=model)
        except Exception as exc:
            await safe_emit(task, task_failed_event(task, self.kind, error=str(exc)))
            raise

        for tool_call in result.tool_calls:
            await safe_emit(task, tool_completed_event(task, self.kind, tool_call))
        if result.output:
            await safe_emit(task, output_delta_event(task, self.kind, text=result.output))
        if result.error:
            await safe_emit(task, task_failed_event(task, self.kind, error=result.error))
        else:
            await safe_emit(task, task_completed_event(task, self.kind, result))
        return result

    async def cancel(self, task_id: str) -> None:
        """Claude ``query`` calls do not expose a portable cancellation handle."""

        del task_id

    def _load_sdk(self) -> tuple[Any, Any]:
        """Return ``(query_func, options_cls)``, preferring injected objects.

        Raises:
            RuntimeError: if an import is needed and ``claude_agent_sdk`` is missing.
        """
        if self._query_func is not None and self._options_cls is not None:
            return self._query_func, self._options_cls
        try:
            from claude_agent_sdk import ClaudeAgentOptions, query  # type: ignore[import-not-found]
        except ImportError as exc:
            raise RuntimeError(
                "claude-agent-sdk is not installed. Install agent-runtime-kit[claude]."
            ) from exc
        # Only the missing half is taken from the SDK.
        return self._query_func or query, self._options_cls or ClaudeAgentOptions

    def _build_options(self, task: AgentTask, model: str, options_cls: Any) -> Any:
        """Translate ``task`` into an ``options_cls`` instance.

        Optional settings are added only when the task sets them. The final
        kwargs go through ``filter_supported_kwargs`` before ``options_cls`` is
        constructed, presumably to drop keys the installed SDK version does not
        accept (TODO confirm against ``_common``). ``metadata["setting_sources"]``
        is forwarded when it is a list or tuple, with each entry stringified.
        """
        metadata = task.metadata
        kwargs: dict[str, Any] = {
            "model": model,
            "allowed_tools": list(task.permissions.allowed_tools),
            "disallowed_tools": list(task.permissions.disallowed_tools),
            "permission_mode": _permission_mode(task.permissions.mode),
        }
        if task.system:
            kwargs["system_prompt"] = task.system
        if task.working_directory is not None:
            kwargs["cwd"] = task.working_directory
        if task.mcp_servers:
            kwargs["mcp_servers"] = {
                server.name: {
                    "type": "stdio",
                    "command": server.command,
                    "args": list(server.args),
                    "env": dict(server.env),
                }
                for server in task.mcp_servers
            }
        # An explicit resume target wins over the task's own session id.
        if task.resume_from is not None:
            kwargs["resume"] = task.resume_from.session_id
        elif task.session_id:
            kwargs["resume"] = task.session_id
        if task.budget_usd is not None:
            kwargs["max_budget_usd"] = task.budget_usd
        output_schema = output_schema_from(task.output_schema, metadata)
        if output_schema is not None:
            kwargs["output_format"] = {"type": "json_schema", "schema": dict(output_schema)}
        setting_sources = metadata.get("setting_sources")
        # Tuples are accepted as well as lists; any other type is ignored.
        if isinstance(setting_sources, (list, tuple)):
            kwargs["setting_sources"] = [str(item) for item in setting_sources]
        return options_cls(**filter_supported_kwargs(options_cls, kwargs))

    def _model(self, task: AgentTask) -> str:
        """Return ``task.metadata["model"]`` when set (via ``metadata_str``), else the default model."""
        return metadata_str(task.metadata, "model") or self._default_model
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
async def _collect_messages(candidate: Any) -> list[Any]:
|
|
161
|
+
if inspect.isawaitable(candidate):
|
|
162
|
+
candidate = await candidate
|
|
163
|
+
if hasattr(candidate, "__aiter__"):
|
|
164
|
+
return [message async for message in _as_async_iter(candidate)]
|
|
165
|
+
if isinstance(candidate, Iterable) and not isinstance(candidate, bytes | str | Mapping):
|
|
166
|
+
return list(candidate)
|
|
167
|
+
return [candidate]
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
async def _as_async_iter(candidate: AsyncIterator[Any]) -> AsyncIterator[Any]:
|
|
171
|
+
async for item in candidate:
|
|
172
|
+
yield item
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _translate_messages(task: AgentTask, messages: list[Any], *, model: str) -> AgentResult:
    """Fold Claude Agent SDK messages into a single ``AgentResult``.

    What each message type contributes:

    * Assistant messages: text blocks, tool-use audits, session id, usage, and
      any per-message ``error``.
    * Result messages: the final ``result`` text (only when no assistant text
      was collected), ``structured_output``, total cost, usage, turn count and
      session id. When ``is_error`` is set, ``errors`` becomes the result error.
    * Other mapping-shaped messages: their ``content`` and ``error``.

    If an output schema is configured and the SDK supplied no structured output,
    the joined text is parsed with ``parse_json_output``.

    Args:
        task: The task being run; seeds ``session_id`` and provides the schema.
        messages: Messages collected from ``query``.
        model: Model name recorded in the result metadata.
    """
    content_parts: list[str] = []
    tool_calls: list[ToolCallAudit] = []
    usage = Usage()
    cost_usd = 0.0
    session_id = task.session_id
    rounds = 0
    error: str | None = None
    structured_output: Any | None = None

    for message in messages:
        message_type = _message_type(message)
        if message_type in {"AssistantMessage", "assistant"}:
            text, tools = _assistant_content(message)
            content_parts.extend(text)
            tool_calls.extend(tools)
            session_id = _optional_str(_field(message, "session_id")) or session_id
            usage = _usage_from(_field(message, "usage"), current=usage)
            message_error = _field(message, "error")
            if message_error:
                error = str(message_error)
        elif message_type in {"ResultMessage", "result"}:
            result_text = _field(message, "result")
            # The final result usually repeats assistant text; use it only as a fallback.
            if result_text and not content_parts:
                content_parts.append(str(result_text))
            structured_output = _field(message, "structured_output", structured_output)
            cost_usd = float(_field(message, "total_cost_usd", cost_usd) or cost_usd)
            usage = _usage_from(_field(message, "usage"), current=usage)
            rounds = int(_field(message, "num_turns", rounds) or rounds)
            session_id = _optional_str(_field(message, "session_id")) or session_id
            if _field(message, "is_error", False):
                errors = _field(message, "errors", ()) or ()
                # A bare string (or other scalar) must be treated as one error.
                # Otherwise "boom" would be joined as "b; o; o; m", and a
                # non-iterable value would raise TypeError.
                if isinstance(errors, (str, bytes)) or not isinstance(errors, Iterable):
                    errors = (errors,)
                error = "; ".join(str(item) for item in errors) or "Claude Agent SDK task failed"
        elif isinstance(message, Mapping):
            if message.get("content"):
                content_parts.append(str(message["content"]))
            if message.get("error"):
                error = str(message["error"])

    output = "\n".join(part for part in content_parts if part).strip()
    if (
        structured_output is None
        and output_schema_from(task.output_schema, task.metadata) is not None
    ):
        structured_output = parse_json_output(output)
    # Attach the accumulated cost, which ``_usage_from`` only carries over.
    usage = Usage(
        input_tokens=usage.input_tokens,
        output_tokens=usage.output_tokens,
        cache_read_tokens=usage.cache_read_tokens,
        cache_creation_tokens=usage.cache_creation_tokens,
        total_tokens=usage.total_tokens,
        cost_usd=cost_usd,
    )
    return AgentResult(
        output=output,
        finish_reason="failed" if error else "done",
        error=error,
        parsed_output=structured_output,
        usage=usage,
        tool_calls=tuple(tool_calls),
        session_id=session_id,
        rounds=rounds,
        metadata={"model": model, "sdk": "claude_agent_sdk"},
    )
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _permission_mode(mode: PermissionMode) -> str:
    """Map a portable ``PermissionMode`` onto the Claude SDK ``permission_mode`` string.

    The mapping is strict → ``"plan"``, cautious → ``"acceptEdits"`` and
    permissive → ``"bypassPermissions"``. Anything else (compared by identity)
    maps to ``"default"``.
    """
    for member, sdk_mode in (
        (PermissionMode.STRICT, "plan"),
        (PermissionMode.CAUTIOUS, "acceptEdits"),
        (PermissionMode.PERMISSIVE, "bypassPermissions"),
    ):
        if mode is member:
            return sdk_mode
    return "default"
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def _message_type(message: Any) -> str:
|
|
252
|
+
if isinstance(message, Mapping):
|
|
253
|
+
return str(message.get("type") or "")
|
|
254
|
+
return type(message).__name__
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _assistant_content(message: Any) -> tuple[list[str], list[ToolCallAudit]]:
    """Split an assistant message's content into text fragments and tool-use audits.

    How ``content`` is handled:

    * String content is returned as a single text fragment.
    * Non-iterable content yields nothing.
    * In a block list, text blocks contribute their ``text`` and tool-use blocks
      become ``ToolCallAudit`` entries (a non-mapping ``input`` becomes ``{}``).
      Other block types are skipped.
    """
    content = _field(message, "content", ())
    if isinstance(content, str):
        return [content], []
    texts: list[str] = []
    audits: list[ToolCallAudit] = []
    if not isinstance(content, Iterable):
        return texts, audits
    for block in content:
        kind = _message_type(block)
        if kind in {"TextBlock", "text"}:
            texts.append(str(_field(block, "text", "")))
            continue
        if kind not in {"ToolUseBlock", "tool_use"}:
            continue
        tool_name = str(_field(block, "name", "tool"))
        raw_input = _field(block, "input", {})
        audits.append(
            ToolCallAudit(
                tool_name=tool_name,
                arguments=raw_input if isinstance(raw_input, Mapping) else {},
            )
        )
    return texts, audits
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _field(value: Any, name: str, default: Any = None) -> Any:
|
|
278
|
+
if isinstance(value, Mapping):
|
|
279
|
+
return value.get(name, default)
|
|
280
|
+
return getattr(value, name, default)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def _usage_from(value: Any, *, current: Usage) -> Usage:
    """Overlay an Anthropic-style usage mapping onto ``current``.

    Each counter that is present and non-zero in ``value`` replaces the matching
    counter in ``current``. Missing or zero counters keep the current value.
    ``total_tokens`` is recomputed as the sum of the four counters, and
    ``cost_usd`` is carried over. A non-mapping ``value`` returns ``current``
    unchanged.
    """
    if not isinstance(value, Mapping):
        return current

    def pick(key: str, fallback: int) -> int:
        return int(value.get(key) or fallback)

    input_tokens = pick("input_tokens", current.input_tokens)
    output_tokens = pick("output_tokens", current.output_tokens)
    cache_creation = pick("cache_creation_input_tokens", current.cache_creation_tokens)
    cache_read = pick("cache_read_input_tokens", current.cache_read_tokens)
    return Usage(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cache_read_tokens=cache_read,
        cache_creation_tokens=cache_creation,
        total_tokens=input_tokens + output_tokens + cache_creation + cache_read,
        cost_usd=current.cost_usd,
    )
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def _optional_str(value: Any) -> str | None:
|
|
302
|
+
return str(value) if value else None
|