codebuddy-agent-sdk 0.3.7__py3-none-manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codebuddy_agent_sdk/__init__.py +133 -0
- codebuddy_agent_sdk/_binary.py +150 -0
- codebuddy_agent_sdk/_errors.py +54 -0
- codebuddy_agent_sdk/_message_parser.py +122 -0
- codebuddy_agent_sdk/_version.py +3 -0
- codebuddy_agent_sdk/bin/codebuddy +0 -0
- codebuddy_agent_sdk/client.py +394 -0
- codebuddy_agent_sdk/mcp/__init__.py +35 -0
- codebuddy_agent_sdk/mcp/create_sdk_mcp_server.py +154 -0
- codebuddy_agent_sdk/mcp/sdk_control_server_transport.py +95 -0
- codebuddy_agent_sdk/mcp/types.py +300 -0
- codebuddy_agent_sdk/py.typed +0 -0
- codebuddy_agent_sdk/query.py +340 -0
- codebuddy_agent_sdk/transport/__init__.py +6 -0
- codebuddy_agent_sdk/transport/base.py +31 -0
- codebuddy_agent_sdk/transport/subprocess.py +341 -0
- codebuddy_agent_sdk/types.py +395 -0
- codebuddy_agent_sdk-0.3.7.dist-info/METADATA +89 -0
- codebuddy_agent_sdk-0.3.7.dist-info/RECORD +20 -0
- codebuddy_agent_sdk-0.3.7.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
"""Type definitions for SDK MCP Server."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import inspect
|
|
6
|
+
from collections.abc import Awaitable, Callable
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from .sdk_control_server_transport import SdkControlServerTransport
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# ============= JSON-RPC Types =============
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class JSONRPCRequest(TypedDict, total=False):
    """JSON-RPC 2.0 request.

    Declared with ``total=False``, so every key is optional at the type level.
    """

    # Protocol version marker; the Literal type pins it to "2.0".
    jsonrpc: Literal["2.0"]
    # Request identifier; SdkMcpServer echoes it back as the response "id".
    id: str | int
    # Name of the method to invoke (e.g. "initialize", "tools/call").
    method: str
    # By-name (object) or positional (array) parameters, or None.
    params: dict[str, Any] | list[Any] | None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class JSONRPCError(TypedDict, total=False):
    """JSON-RPC 2.0 error object (the value of a response's ``error`` key)."""

    # Numeric error code; this module emits -32601 (method not found) and
    # -32602 (invalid params).
    code: int
    # Short human-readable description of the error.
    message: str
    # Optional additional information about the error.
    data: Any
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class JSONRPCResponse(TypedDict, total=False):
    """JSON-RPC 2.0 response.

    Per JSON-RPC 2.0, a response carries either ``result`` (success) or
    ``error`` (failure), never both. ``total=False`` lets either be omitted.
    """

    # Protocol version marker; always "2.0".
    jsonrpc: Literal["2.0"]
    # Id of the request being answered; None when it could not be determined.
    id: str | int | None
    # Method-specific payload on success.
    result: Any
    # Error details on failure.
    error: JSONRPCError
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class JSONRPCNotification(TypedDict, total=False):
    """JSON-RPC 2.0 notification: a request without an ``id``; no reply is sent."""

    # Protocol version marker; always "2.0".
    jsonrpc: Literal["2.0"]
    # Notification name (e.g. "notifications/initialized").
    method: str
    # By-name (object) or positional (array) parameters, or None.
    params: dict[str, Any] | list[Any] | None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Any JSON-RPC message exchanged with the CLI.
# NOTE(review): this alias is evaluated at import time (it is not an annotation),
# so the runtime `|` between classes requires Python 3.10+ — confirm the
# package's minimum supported Python version.
JSONRPCMessage = JSONRPCRequest | JSONRPCResponse | JSONRPCNotification
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# ============= MCP Tool Types =============
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class TextContent(TypedDict, total=False):
    """Text content in tool result."""

    # Discriminator for the content union; always "text".
    type: Literal["text"]
    # The text payload.
    text: str
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class ImageContent(TypedDict, total=False):
    """Image content in tool result."""

    # Discriminator for the content union; always "image".
    type: Literal["image"]
    # Image payload; presumably base64-encoded bytes per the MCP convention — confirm.
    data: str
    # MIME type of the image, e.g. "image/png".
    mimeType: str
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class EmbeddedResource(TypedDict, total=False):
    """Embedded resource content in tool result."""

    # Discriminator for the content union; always "resource".
    type: Literal["resource"]
    # Resource payload as a free-form dict; its schema is not enforced here.
    resource: dict[str, Any]
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# One item of CallToolResult["content"], discriminated by its "type" key.
# NOTE(review): evaluated at import time; the runtime `|` between classes
# requires Python 3.10+ — confirm the minimum supported version.
ToolResultContent = TextContent | ImageContent | EmbeddedResource
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class CallToolResult(TypedDict, total=False):
    """Result from calling a tool (the ``result`` of a ``tools/call`` response)."""

    # Ordered content items produced by the tool.
    content: list[ToolResultContent]
    # True when the tool failed; SdkMcpServer sets it when a handler raises.
    isError: bool
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# Tool handler type: receives the "arguments" dict of a tools/call request and
# returns a CallToolResult, either directly (sync handler) or as an awaitable
# (async handler); SdkMcpServer awaits the result when it is awaitable.
# NOTE(review): evaluated at import time (not an annotation), so
# `X | Awaitable[X]` requires Python 3.10+ — confirm the minimum supported version.
ToolHandler = Callable[[dict[str, Any]], CallToolResult | Awaitable[CallToolResult]]
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass
class ToolInputProperty:
    """Property definition for tool input schema.

    Mirrors common JSON Schema keywords for a single property.
    NOTE(review): ToolInputSchema.properties holds plain dicts rather than
    instances of this class; callers must convert it themselves if used.
    """

    # JSON Schema type name, e.g. "string", "number", "boolean".
    type: str
    # Human-readable description shown to the model.
    description: str | None = None
    # Allowed values, if the property is an enumeration.
    enum: list[str] | None = None
    # Default value; None here also means "no default declared".
    default: Any = None
    # Inclusive numeric lower bound, if any.
    minimum: float | None = None
    # Inclusive numeric upper bound, if any.
    maximum: float | None = None
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@dataclass
class ToolInputSchema:
    """JSON Schema for tool input.

    The three fields are copied verbatim into the ``inputSchema`` object that
    SdkMcpServer reports for each tool in its ``tools/list`` response.
    """

    # Top-level schema type; tool inputs are always JSON objects.
    type: Literal["object"] = "object"
    # Property name -> JSON Schema fragment for that property.
    properties: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Names of properties the caller must supply.
    required: list[str] = field(default_factory=list)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@dataclass
class SdkMcpToolDefinition:
    """
    Tool definition for SDK MCP Server.

    Attributes:
        name: Tool name; SdkMcpServer registers tools under this key.
        description: Description advertised in ``tools/list``.
        input_schema: JSON Schema describing the handler's ``arguments`` dict.
        handler: Sync or async callable invoked for ``tools/call``.

    Example:
        ```python
        tool_def = SdkMcpToolDefinition(
            name="get_weather",
            description="Get the current weather for a location",
            input_schema=ToolInputSchema(
                properties={
                    "location": {"type": "string", "description": "The city name"},
                    "units": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                required=["location"],
            ),
            handler=get_weather_handler,
        )
        ```
    """

    name: str
    description: str
    input_schema: ToolInputSchema
    handler: ToolHandler
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@dataclass
class SdkMcpServerOptions:
    """
    Options for creating an SDK MCP Server.

    Attributes:
        name: Server name (must be unique within the session)
        version: Server version (defaults to "1.0.0")
        tools: List of tool definitions to register. SdkMcpServer keys them
            by name, so a later tool with a duplicate name replaces an earlier one.
    """

    name: str
    version: str = "1.0.0"
    # default_factory gives each instance its own list (no shared mutable default).
    tools: list[SdkMcpToolDefinition] = field(default_factory=list)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@dataclass
class SdkMcpServerResult:
    """
    Result type for create_sdk_mcp_server.

    Attributes:
        type: Type discriminator - always "sdk" for SDK MCP servers
        name: Server name
        server: The MCP server instance
    """

    type: Literal["sdk"]
    name: str
    # Forward reference to SdkMcpServer (defined after this class); it is fine
    # because the module uses `from __future__ import annotations`.
    server: SdkMcpServer
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class SdkMcpServer:
    """
    SDK MCP Server implementation.

    This class implements an MCP server that runs within the SDK process
    and communicates with the CLI via the control protocol.

    Handled JSON-RPC methods:

    * ``initialize`` - returns protocol version, capabilities and server info.
    * ``ping`` - returns an empty result (MCP liveness check).
    * ``tools/list`` - describes every registered tool.
    * ``tools/call`` - invokes a registered tool handler (sync or async).
    * ``notifications/initialized`` - accepted silently.

    Any other request (a message with an ``id``) gets a -32601
    "Method not found" error. Any other notification is ignored.
    """

    def __init__(self, options: SdkMcpServerOptions):
        """Create the server from *options*, registering its tools by name.

        A later tool with the same name replaces an earlier one.
        """
        self.name = options.name
        self.version = options.version
        self.tools: dict[str, SdkMcpToolDefinition] = {}
        # Set by connect(); the message handlers below do not use it.
        self._transport: SdkControlServerTransport | None = None

        for tool_def in options.tools:
            self.tools[tool_def.name] = tool_def

    def connect(self, transport: SdkControlServerTransport) -> None:
        """Connect the server to a transport."""
        self._transport = transport

    async def handle_message(self, message: JSONRPCMessage) -> JSONRPCMessage | None:
        """Handle an incoming JSON-RPC message.

        Args:
            message: A decoded JSON-RPC request or notification.

        Returns:
            The JSON-RPC response to send back, or ``None`` when no reply is
            produced: the message has no ``method``, it is
            ``notifications/initialized``, or it is an unrecognised
            notification or ``ping`` without an ``id``.
        """
        # Only requests and notifications carry a method; anything else
        # (e.g. a response) needs no reply.
        if "method" not in message:
            return None

        method = cast(str, message.get("method", ""))
        msg_id = cast("str | int | None", message.get("id"))
        params = cast("dict[str, Any] | None", message.get("params", {}))

        if method == "initialize":
            return await self._handle_initialize(msg_id, params)
        if method == "ping":
            # MCP ping: a request must be answered promptly with an empty result.
            return self._create_response(msg_id, {}) if msg_id is not None else None
        if method == "tools/list":
            return await self._handle_tools_list(msg_id)
        if method == "tools/call":
            return await self._handle_tools_call(msg_id, params)
        if method == "notifications/initialized":
            # Notification, no response needed.
            return None

        # Unknown method: only requests (which carry an id) get an error reply.
        if msg_id is not None:
            return self._create_error_response(msg_id, -32601, f"Method not found: {method}")
        return None

    async def _handle_initialize(
        self, msg_id: str | int | None, params: dict[str, Any] | None
    ) -> JSONRPCMessage:
        """Handle initialize request: advertise protocol version, tools capability and server info."""
        result = {
            "protocolVersion": "2024-11-05",
            "capabilities": {
                "tools": {},
            },
            "serverInfo": {
                "name": self.name,
                "version": self.version,
            },
        }
        return self._create_response(msg_id, result)

    async def _handle_tools_list(self, msg_id: str | int | None) -> JSONRPCMessage:
        """Handle tools/list request: describe every registered tool in registration order."""
        tools_list = [
            {
                "name": tool_def.name,
                "description": tool_def.description,
                "inputSchema": {
                    "type": tool_def.input_schema.type,
                    "properties": tool_def.input_schema.properties,
                    "required": tool_def.input_schema.required,
                },
            }
            for tool_def in self.tools.values()
        ]
        return self._create_response(msg_id, {"tools": tools_list})

    async def _handle_tools_call(
        self, msg_id: str | int | None, params: dict[str, Any] | None
    ) -> JSONRPCMessage:
        """Handle tools/call request.

        Protocol problems (non-object params, unknown tool, non-object
        ``arguments``) become JSON-RPC -32602 errors. An exception raised by
        the tool handler is reported as a successful response carrying a
        CallToolResult with ``isError: True``, following the MCP convention
        that tool failures are tool results.
        """
        if not isinstance(params, dict):
            return self._create_error_response(msg_id, -32602, "Invalid params")

        tool_name = params.get("name", "")
        # Only string names can match a registered tool; the isinstance guard
        # also keeps an unhashable name (list/dict) from raising TypeError.
        tool_def = self.tools.get(tool_name) if isinstance(tool_name, str) else None
        if tool_def is None:
            return self._create_error_response(msg_id, -32602, f"Tool not found: {tool_name}")

        arguments = params.get("arguments")
        if arguments is None:
            # Missing or explicit-null "arguments" means "no arguments".
            arguments = {}
        elif not isinstance(arguments, dict):
            return self._create_error_response(
                msg_id, -32602, "Invalid params: 'arguments' must be an object"
            )

        try:
            result = tool_def.handler(arguments)
            # Handlers may be sync or async.
            if inspect.isawaitable(result):
                result = await result
            return self._create_response(msg_id, result)
        except Exception as e:
            # Return error as tool result; fall back to the exception type name
            # so the model never sees an empty error message.
            error_result: CallToolResult = {
                "content": [{"type": "text", "text": str(e) or type(e).__name__}],
                "isError": True,
            }
            return self._create_response(msg_id, error_result)

    def _create_response(self, msg_id: str | int | None, result: Any) -> JSONRPCResponse:
        """Create a JSON-RPC success response wrapping *result*."""
        return {
            "jsonrpc": "2.0",
            "id": msg_id,
            "result": result,
        }

    def _create_error_response(
        self, msg_id: str | int | None, code: int, message: str
    ) -> JSONRPCResponse:
        """Create a JSON-RPC error response with the given *code* and *message*."""
        return {
            "jsonrpc": "2.0",
            "id": msg_id,
            "error": {"code": code, "message": message},
        }
|
|
File without changes
|
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
"""Query function for one-shot interactions with CodeBuddy."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from collections.abc import AsyncIterable, AsyncIterator
|
|
8
|
+
from dataclasses import asdict
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from ._errors import ExecutionError
|
|
12
|
+
from ._message_parser import parse_message
|
|
13
|
+
from .transport import SubprocessTransport, Transport
|
|
14
|
+
from .types import (
|
|
15
|
+
AppendSystemPrompt,
|
|
16
|
+
CanUseToolOptions,
|
|
17
|
+
CodeBuddyAgentOptions,
|
|
18
|
+
ErrorMessage,
|
|
19
|
+
HookCallback,
|
|
20
|
+
HookMatcher,
|
|
21
|
+
Message,
|
|
22
|
+
ResultMessage,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
async def query(
    *,
    prompt: str | AsyncIterable[dict[str, Any]],
    options: CodeBuddyAgentOptions | None = None,
    transport: Transport | None = None,
) -> AsyncIterator[Message]:
    """
    Query CodeBuddy for one-shot or unidirectional streaming interactions.

    This function is ideal for simple, stateless queries where you don't need
    bidirectional communication or conversation management. For interactive,
    stateful conversations, use CodeBuddySDKClient instead.

    Args:
        prompt: The prompt to send to CodeBuddy. Can be a string for single-shot
            queries or an AsyncIterable[dict] for streaming mode.
        options: Optional configuration (defaults to CodeBuddyAgentOptions() if None).
        transport: Optional transport implementation. If provided, this will be used
            instead of the default subprocess transport.

    Yields:
        Messages from the conversation. Iteration stops after the first
        ResultMessage or ErrorMessage (both are yielded).

    Raises:
        ExecutionError: If the ResultMessage reports ``is_error`` with a
            non-empty ``errors`` list. It is raised instead of being yielded.

    Example:
        ```python
        async for message in query(prompt="What is 2+2?"):
            print(message)
        ```
    """
    if options is None:
        options = CodeBuddyAgentOptions()

    # NOTE: mutates os.environ for the whole process, not just this call.
    os.environ["CODEBUDDY_CODE_ENTRYPOINT"] = "sdk-py"

    # Transport handles SDK MCP server extraction automatically
    if transport is None:
        transport = SubprocessTransport(options=options, prompt=prompt)

    # connect() runs before the try, so a failed connect does not trigger close().
    await transport.connect()

    try:
        hook_callbacks = await _send_initialize(transport, options)
        await _send_prompt(transport, prompt)

        async for line in transport.read():
            if not line:
                continue

            # Keep the try narrow so only JSON decoding errors are ignored.
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue  # Ignore non-JSON lines
            if not isinstance(data, dict):
                continue  # Valid JSON but not an object: not a protocol message

            # Handle control requests (hooks, permissions, MCP messages)
            if data.get("type") == "control_request":
                await _handle_control_request(transport, data, options, hook_callbacks)
                continue

            message = parse_message(data)
            if not message:
                continue

            if isinstance(message, ResultMessage):
                # Check for execution error BEFORE yielding
                if message.is_error and message.errors:
                    raise ExecutionError(message.errors, message.subtype)
                yield message
                break

            yield message

            if isinstance(message, ErrorMessage):
                break

    finally:
        await transport.close()
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
async def _send_initialize(
    transport: Transport,
    options: CodeBuddyAgentOptions,
) -> dict[str, HookCallback]:
    """Write the ``initialize`` control request to the CLI.

    The payload bundles the hook matcher configuration, the system prompt
    settings, custom agent definitions and the in-process SDK MCP server names.

    Args:
        transport: Connected transport the request is written to.
        options: Agent options the payload is derived from.

    Returns:
        Hook callbacks registry (callback_id -> hook function), used later to
        dispatch ``hook_callback`` control requests.
    """
    hooks_config, hook_callbacks = _build_hooks_config(options.hooks)

    agents_config: dict[str, Any] | None = None
    if options.agents:
        agents_config = {
            agent_name: asdict(agent_def)
            for agent_name, agent_def in options.agents.items()
        }

    # A plain string is sent as "systemPrompt"; an AppendSystemPrompt's
    # `append` text is sent as "appendSystemPrompt". Anything else sends neither.
    configured_prompt = options.system_prompt
    system_prompt: str | None = None
    append_system_prompt: str | None = None
    if isinstance(configured_prompt, str):
        system_prompt = configured_prompt
    elif isinstance(configured_prompt, AppendSystemPrompt):
        append_system_prompt = configured_prompt.append

    mcp_server_names = transport.sdk_mcp_server_names

    init_payload = {
        "subtype": "initialize",
        "hooks": hooks_config,
        "systemPrompt": system_prompt,
        "appendSystemPrompt": append_system_prompt,
        "agents": agents_config,
        # An empty name list is sent as null.
        "sdkMcpServers": mcp_server_names or None,
    }
    envelope = {
        "type": "control_request",
        "request_id": f"init_{id(options)}",
        "request": init_payload,
    }
    await transport.write(json.dumps(envelope))
    return hook_callbacks
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
async def _send_prompt(transport: Transport, prompt: str | AsyncIterable[dict[str, Any]]) -> None:
    """Write the user prompt to the CLI.

    A plain string is wrapped in a single ``user`` message envelope. Any other
    value is treated as an async iterable of pre-built messages, and each item
    is written verbatim, one write per item, in iteration order.
    """
    if not isinstance(prompt, str):
        async for raw_message in prompt:
            await transport.write(json.dumps(raw_message))
        return

    envelope = {
        "type": "user",
        "session_id": "",
        "message": {"role": "user", "content": prompt},
        "parent_tool_use_id": None,
    }
    await transport.write(json.dumps(envelope))
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
async def _handle_control_request(
    transport: Transport,
    data: dict[str, Any],
    options: CodeBuddyAgentOptions,
    hook_callbacks: dict[str, HookCallback],
) -> None:
    """Dispatch a ``control_request`` received from the CLI by its subtype.

    * ``hook_callback``: run the registered hook and reply with its output.
    * ``can_use_tool``: delegate to the permission handler.
    * ``mcp_message``: forwarded to the transport, only for SubprocessTransport.

    NOTE(review): any other subtype, and ``mcp_message`` on a non-subprocess
    transport, is dropped without a reply. Confirm the CLI does not block
    waiting for one.
    """
    request_id = data.get("request_id", "")
    request = data.get("request", {})
    subtype = request.get("subtype", "")

    if subtype == "can_use_tool":
        await _handle_permission_request(transport, request_id, request, options)
        return

    if subtype == "mcp_message":
        # MCP messages are handled at the transport level.
        if isinstance(transport, SubprocessTransport):
            await transport.handle_mcp_message_request(request_id, request)
        return

    if subtype != "hook_callback":
        return

    # Look up and run the hook via the callback registry.
    hook_output = await _execute_hook(
        request.get("callback_id", ""),
        request.get("input", {}),
        request.get("tool_use_id"),
        hook_callbacks,
    )
    reply = {
        "type": "control_response",
        "response": {
            "subtype": "success",
            "request_id": request_id,
            "response": hook_output,
        },
    }
    await transport.write(json.dumps(reply))
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
async def _handle_permission_request(
    transport: Transport,
    request_id: str,
    request: dict[str, Any],
    options: CodeBuddyAgentOptions,
) -> None:
    """Answer a ``can_use_tool`` permission request from the CLI.

    The decision is delegated to ``options.can_use_tool``. The request is
    denied when no callback is configured, or when the callback (or reading
    its result) raises. Exactly one ``control_response`` is written per call.
    A failure of that write propagates to the caller.

    Args:
        transport: Transport the response is written to.
        request_id: Id of the control request being answered.
        request: The ``request`` object of the control request.
        options: Agent options providing the ``can_use_tool`` callback.
    """
    tool_name = request.get("tool_name", "")
    input_data = request.get("input", {})
    tool_use_id = request.get("tool_use_id", "")
    agent_id = request.get("agent_id")

    can_use_tool = options.can_use_tool

    response_data: dict[str, Any]
    if not can_use_tool:
        # Default deny if no callback provided.
        response_data = {
            "allowed": False,
            "reason": "No permission handler provided",
            "tool_use_id": tool_use_id,
        }
    else:
        # Only the decision is computed inside the try. The single write
        # below is outside it, so a failed write is not answered a second time.
        try:
            callback_options = CanUseToolOptions(
                tool_use_id=tool_use_id,
                signal=None,
                agent_id=agent_id,
                suggestions=request.get("permission_suggestions"),
                blocked_path=request.get("blocked_path"),
                decision_reason=request.get("decision_reason"),
            )

            result = await can_use_tool(tool_name, input_data, callback_options)

            if result.behavior == "allow":
                updated_input = result.updated_input
                response_data = {
                    "allowed": True,
                    # Fall back to the original input when the callback did not
                    # rewrite it, so an allowed call never carries a null input.
                    "updatedInput": updated_input if updated_input is not None else input_data,
                    "tool_use_id": tool_use_id,
                }
            else:
                response_data = {
                    "allowed": False,
                    "reason": result.message,
                    "interrupt": result.interrupt,
                    "tool_use_id": tool_use_id,
                }
        except Exception as e:
            response_data = {
                "allowed": False,
                # Fall back to the exception type when its message is empty.
                "reason": str(e) or type(e).__name__,
                "tool_use_id": tool_use_id,
            }

    response = {
        "type": "control_response",
        "response": {
            "subtype": "success",
            "request_id": request_id,
            "response": response_data,
        },
    }
    await transport.write(json.dumps(response))
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
async def _execute_hook(
    callback_id: str,
    hook_input: dict[str, Any],
    tool_use_id: str | None,
    hook_callbacks: dict[str, HookCallback],
) -> dict[str, Any]:
    """Run the hook registered under *callback_id* and return its output.

    A missing (or falsy) registry entry yields ``{"continue": True}``. The
    hook is awaited with ``(hook_input, tool_use_id, {"signal": None})``, and
    its result is copied into a plain dict. If the call or the copy raises,
    the error is reported as ``{"continue": False, "stopReason": <message>}``.
    """
    callback = hook_callbacks.get(callback_id)
    if callback:
        context = {"signal": None}
        try:
            return dict(await callback(hook_input, tool_use_id, context))
        except Exception as exc:
            return {"continue": False, "stopReason": str(exc)}
    return {"continue": True}
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def _build_hooks_config(
    hooks: dict[Any, list[HookMatcher]] | None,
) -> tuple[dict[str, list[dict[str, Any]]] | None, dict[str, HookCallback]]:
    """Translate hook matchers into the CLI config plus a callback registry.

    Each hook function gets a callback id of the form
    ``hook_<event>_<matcher index>_<hook index>``. The CLI config lists those
    ids per matcher, and the registry maps each id back to its function.

    Returns:
        Tuple of (config for CLI, or None when empty; callback_id -> hook
        function mapping).
    """
    registry: dict[str, HookCallback] = {}
    if not hooks:
        return None, registry

    cli_config: dict[str, list[dict[str, Any]]] = {}
    for event, matchers in hooks.items():
        event_key = str(event)
        entries: list[dict[str, Any]] = []
        for matcher_index, matcher in enumerate(matchers):
            ids: list[str] = []
            for hook_index, callback in enumerate(matcher.hooks):
                callback_id = f"hook_{event_key}_{matcher_index}_{hook_index}"
                ids.append(callback_id)
                registry[callback_id] = callback
            entries.append(
                {
                    "matcher": matcher.matcher,
                    "hookCallbackIds": ids,
                    "timeout": matcher.timeout,
                }
            )
        cli_config[event_key] = entries

    return cli_config or None, registry
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Transport base class for CLI communication."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from collections.abc import AsyncIterator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Transport(ABC):
    """Abstract transport layer for CLI communication.

    Subclasses implement connect/read/write/close. ``query()`` drives a
    transport in that order: it connects, writes serialized messages, consumes
    ``read()`` and closes the transport in a ``finally`` block.
    """

    @abstractmethod
    async def connect(self) -> None:
        """Establish connection to CLI."""

    @abstractmethod
    def read(self) -> AsyncIterator[str]:
        """Read messages from CLI as an async iterator.

        Declared as a plain (non-``async``) method that returns an async
        iterator, so callers write ``async for line in transport.read()``.
        Implementations are typically async generators. Each item is one raw
        text line; ``query()`` skips empty lines, so implementations may yield them.
        """

    @abstractmethod
    async def write(self, data: str) -> None:
        """Write data to CLI.

        ``data`` is one serialized message (``query()`` passes ``json.dumps``
        output). Whether a line terminator is appended is up to the
        implementation. TODO confirm for each concrete transport.
        """

    @abstractmethod
    async def close(self) -> None:
        """Close the connection."""

    @property
    def sdk_mcp_server_names(self) -> list[str]:
        """Get the list of SDK MCP server names. Override in subclasses.

        The names are sent in the ``initialize`` control request. The base
        implementation reports none.
        """
        return []
|