claude-agent-sdk 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of claude-agent-sdk might be problematic; see the package's registry page for details.
- claude_agent_sdk/__init__.py +325 -0
- claude_agent_sdk/_errors.py +56 -0
- claude_agent_sdk/_internal/__init__.py +1 -0
- claude_agent_sdk/_internal/client.py +121 -0
- claude_agent_sdk/_internal/message_parser.py +172 -0
- claude_agent_sdk/_internal/query.py +523 -0
- claude_agent_sdk/_internal/transport/__init__.py +68 -0
- claude_agent_sdk/_internal/transport/subprocess_cli.py +456 -0
- claude_agent_sdk/_version.py +3 -0
- claude_agent_sdk/client.py +325 -0
- claude_agent_sdk/py.typed +0 -0
- claude_agent_sdk/query.py +126 -0
- claude_agent_sdk/types.py +412 -0
- claude_agent_sdk-0.0.23.dist-info/METADATA +309 -0
- claude_agent_sdk-0.0.23.dist-info/RECORD +17 -0
- claude_agent_sdk-0.0.23.dist-info/WHEEL +4 -0
- claude_agent_sdk-0.0.23.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,523 @@
|
|
|
1
|
+
"""Query class for handling bidirectional control protocol."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
|
|
7
|
+
from contextlib import suppress
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
import anyio
|
|
11
|
+
from mcp.types import (
|
|
12
|
+
CallToolRequest,
|
|
13
|
+
CallToolRequestParams,
|
|
14
|
+
ListToolsRequest,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from ..types import (
|
|
18
|
+
PermissionResultAllow,
|
|
19
|
+
PermissionResultDeny,
|
|
20
|
+
SDKControlPermissionRequest,
|
|
21
|
+
SDKControlRequest,
|
|
22
|
+
SDKControlResponse,
|
|
23
|
+
SDKHookCallbackRequest,
|
|
24
|
+
ToolPermissionContext,
|
|
25
|
+
)
|
|
26
|
+
from .transport import Transport
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from mcp.server import Server as McpServer
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)


class Query:
    """Handles bidirectional control protocol on top of Transport.

    This class manages:
    - Control request/response routing
    - Hook callbacks
    - Tool permission callbacks
    - Message streaming
    - Initialization handshake
    """

    def __init__(
        self,
        transport: Transport,
        is_streaming_mode: bool,
        can_use_tool: Callable[
            [str, dict[str, Any], ToolPermissionContext],
            Awaitable[PermissionResultAllow | PermissionResultDeny],
        ]
        | None = None,
        hooks: dict[str, list[dict[str, Any]]] | None = None,
        sdk_mcp_servers: dict[str, "McpServer"] | None = None,
    ):
        """Initialize Query with transport and callbacks.

        Args:
            transport: Low-level transport for I/O
            is_streaming_mode: Whether using streaming (bidirectional) mode
            can_use_tool: Optional callback for tool permission requests
            hooks: Optional hook configurations, keyed by hook event name; each
                matcher dict may carry a "matcher" value and a "hooks" list of
                async callbacks
            sdk_mcp_servers: Optional in-process SDK MCP server instances, keyed
                by server name
        """
        self.transport = transport
        self.is_streaming_mode = is_streaming_mode
        self.can_use_tool = can_use_tool
        self.hooks = hooks or {}
        self.sdk_mcp_servers = sdk_mcp_servers or {}

        # Control protocol state: one Event per outstanding outbound request,
        # plus the result (response dict or Exception) once it is known.
        self.pending_control_responses: dict[str, anyio.Event] = {}
        self.pending_control_results: dict[str, dict[str, Any] | Exception] = {}
        self.hook_callbacks: dict[str, Callable[..., Any]] = {}
        self.next_callback_id = 0
        self._request_counter = 0

        # Message stream carrying non-control messages to receive_messages().
        self._message_send, self._message_receive = anyio.create_memory_object_stream[
            dict[str, Any]
        ](max_buffer_size=100)
        self._tg: anyio.abc.TaskGroup | None = None
        self._initialized = False
        self._closed = False
        self._initialization_result: dict[str, Any] | None = None

    async def initialize(self) -> dict[str, Any] | None:
        """Initialize control protocol if in streaming mode.

        Registers every configured hook callback under a generated ID
        ("hook_<n>") and sends an "initialize" control request advertising
        those IDs to the CLI.

        Returns:
            Initialize response with supported commands, or None if not streaming

        Raises:
            Exception: If the CLI answers with an error or does not answer in time.
        """
        if not self.is_streaming_mode:
            return None

        # Build hooks configuration for initialization
        hooks_config: dict[str, Any] = {}
        if self.hooks:
            for event, matchers in self.hooks.items():
                if matchers:
                    hooks_config[event] = []
                    for matcher in matchers:
                        callback_ids = []
                        for callback in matcher.get("hooks", []):
                            callback_id = f"hook_{self.next_callback_id}"
                            self.next_callback_id += 1
                            self.hook_callbacks[callback_id] = callback
                            callback_ids.append(callback_id)
                        hooks_config[event].append(
                            {
                                "matcher": matcher.get("matcher"),
                                "hookCallbackIds": callback_ids,
                            }
                        )

        # Send initialize request
        request = {
            "subtype": "initialize",
            "hooks": hooks_config if hooks_config else None,
        }

        response = await self._send_control_request(request)
        self._initialized = True
        self._initialization_result = response  # Store for later access
        return response

    async def start(self) -> None:
        """Start reading messages from transport in a background task group."""
        if self._tg is None:
            self._tg = anyio.create_task_group()
            await self._tg.__aenter__()
            self._tg.start_soon(self._read_messages)

    def _resolve_control_response(self, response: dict[str, Any]) -> None:
        """Store the result for a pending outbound control request and wake its waiter.

        Responses whose request_id is not pending (unknown, or already timed
        out and cleaned up) are ignored.
        """
        request_id = response.get("request_id")
        event = self.pending_control_responses.get(request_id)
        if event is None:
            return
        if response.get("subtype") == "error":
            self.pending_control_results[request_id] = Exception(
                response.get("error", "Unknown error")
            )
        else:
            self.pending_control_results[request_id] = response
        event.set()

    def _fail_pending_control_requests(self, error: Exception) -> None:
        """Resolve every still-unanswered outbound control request with ``error``.

        Called when the reader stops, since no response can arrive afterwards;
        waiters then fail immediately instead of waiting for their timeout.
        """
        for request_id, event in list(self.pending_control_responses.items()):
            if request_id not in self.pending_control_results:
                self.pending_control_results[request_id] = error
                event.set()

    async def _read_messages(self) -> None:
        """Read messages from transport and route them.

        - "control_response" messages resolve the matching pending request.
        - "control_request" messages are handled in their own task.
        - "control_cancel_request" messages are currently ignored.
        - Everything else is forwarded to the SDK message stream.

        An {"type": "error"} marker is pushed if the transport raises, and an
        {"type": "end"} marker is always pushed when reading stops.
        """
        try:
            async for message in self.transport.read_messages():
                if self._closed:
                    break

                msg_type = message.get("type")

                # Route control messages
                if msg_type == "control_response":
                    self._resolve_control_response(message.get("response", {}))
                    continue

                elif msg_type == "control_request":
                    # Handle incoming control requests from CLI
                    # Cast message to SDKControlRequest for type safety
                    request: SDKControlRequest = message  # type: ignore[assignment]
                    if self._tg:
                        self._tg.start_soon(self._handle_control_request, request)
                    continue

                elif msg_type == "control_cancel_request":
                    # Handle cancel requests
                    # TODO: Implement cancellation support
                    continue

                # Regular SDK messages go to the stream
                await self._message_send.send(message)

            # The transport finished: nothing can answer outstanding requests now.
            self._fail_pending_control_requests(
                Exception("Transport closed before control response was received")
            )

        except anyio.get_cancelled_exc_class():
            # Task was cancelled - this is expected behavior
            logger.debug("Read task cancelled")
            raise  # Re-raise to properly handle cancellation
        except Exception as e:
            logger.error("Fatal error in message reader: %s", e)
            # Fail pending control requests fast instead of letting them time out
            self._fail_pending_control_requests(e)
            # Put error in stream so iterators can handle it
            await self._message_send.send({"type": "error", "error": str(e)})
        finally:
            # Always signal end of stream
            await self._message_send.send({"type": "end"})

    async def _handle_control_request(self, request: SDKControlRequest) -> None:
        """Handle incoming control request from CLI.

        Dispatches on the request subtype ("can_use_tool", "hook_callback",
        "mcp_message") and always writes exactly one control_response back:
        "success" with the handler's payload, or "error" with the exception text.
        """
        request_id = request["request_id"]
        request_data = request["request"]

        try:
            # Read inside the try so a malformed request yields an error
            # response instead of crashing the reader's task group.
            subtype = request_data.get("subtype")
            response_data: dict[str, Any] = {}

            if subtype == "can_use_tool":
                permission_request: SDKControlPermissionRequest = request_data  # type: ignore[assignment]
                # Handle tool permission request
                if not self.can_use_tool:
                    raise Exception("canUseTool callback is not provided")

                context = ToolPermissionContext(
                    signal=None,  # TODO: Add abort signal support
                    suggestions=permission_request.get("permission_suggestions", [])
                    or [],
                )

                response = await self.can_use_tool(
                    permission_request["tool_name"],
                    permission_request["input"],
                    context,
                )

                # Convert PermissionResult to expected dict format
                if isinstance(response, PermissionResultAllow):
                    response_data = {"allow": True}
                    if response.updated_input is not None:
                        response_data["input"] = response.updated_input
                    # TODO: Handle updatedPermissions when control protocol supports it
                elif isinstance(response, PermissionResultDeny):
                    response_data = {"allow": False, "reason": response.message}
                    # TODO: Handle interrupt flag when control protocol supports it
                else:
                    raise TypeError(
                        f"Tool permission callback must return PermissionResult (PermissionResultAllow or PermissionResultDeny), got {type(response)}"
                    )

            elif subtype == "hook_callback":
                hook_callback_request: SDKHookCallbackRequest = request_data  # type: ignore[assignment]
                # Handle hook callback registered during initialize()
                callback_id = hook_callback_request["callback_id"]
                callback = self.hook_callbacks.get(callback_id)
                if not callback:
                    raise Exception(f"No hook callback found for ID: {callback_id}")

                response_data = await callback(
                    request_data.get("input"),
                    request_data.get("tool_use_id"),
                    {"signal": None},  # TODO: Add abort signal support
                )

            elif subtype == "mcp_message":
                # Handle SDK MCP request
                server_name = request_data.get("server_name")
                mcp_message = request_data.get("message")

                if not server_name or not mcp_message:
                    raise Exception("Missing server_name or message for MCP request")

                # Type narrowing - we've verified these are not None above
                assert isinstance(server_name, str)
                assert isinstance(mcp_message, dict)
                mcp_response = await self._handle_sdk_mcp_request(
                    server_name, mcp_message
                )
                # Wrap the MCP response as expected by the control protocol
                response_data = {"mcp_response": mcp_response}

            else:
                raise Exception(f"Unsupported control request subtype: {subtype}")

            # Send success response
            success_response: SDKControlResponse = {
                "type": "control_response",
                "response": {
                    "subtype": "success",
                    "request_id": request_id,
                    "response": response_data,
                },
            }
            await self.transport.write(json.dumps(success_response) + "\n")

        except Exception as e:
            # Send error response
            error_response: SDKControlResponse = {
                "type": "control_response",
                "response": {
                    "subtype": "error",
                    "request_id": request_id,
                    "error": str(e),
                },
            }
            await self.transport.write(json.dumps(error_response) + "\n")

    async def _send_control_request(
        self, request: dict[str, Any], timeout: float = 60.0
    ) -> dict[str, Any]:
        """Send control request to CLI and wait for response.

        Args:
            request: Control request payload (must carry a "subtype").
            timeout: Seconds to wait for the CLI's response.

        Returns:
            The "response" payload of a successful control response, or an
            empty dict when it is missing or not a dict.

        Raises:
            Exception: When not in streaming mode, when the CLI replies with an
                error, when the reader stops before a reply arrives, or when no
                reply arrives within ``timeout`` seconds.
        """
        if not self.is_streaming_mode:
            raise Exception("Control requests require streaming mode")

        # Generate unique request ID
        self._request_counter += 1
        request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}"

        # Register the event before writing so a fast reply cannot be missed.
        event = anyio.Event()
        self.pending_control_responses[request_id] = event

        control_request = {
            "type": "control_request",
            "request_id": request_id,
            "request": request,
        }

        try:
            await self.transport.write(json.dumps(control_request) + "\n")

            # Wait for response
            try:
                with anyio.fail_after(timeout):
                    await event.wait()
            except TimeoutError as e:
                raise Exception(
                    f"Control request timeout: {request.get('subtype')}"
                ) from e

            result = self.pending_control_results[request_id]
        finally:
            # Drop bookkeeping on every exit path (success, timeout, write
            # failure, cancellation) so the pending maps cannot leak.
            self.pending_control_responses.pop(request_id, None)
            self.pending_control_results.pop(request_id, None)

        if isinstance(result, Exception):
            raise result

        response_data = result.get("response", {})
        return response_data if isinstance(response_data, dict) else {}

    async def _handle_sdk_mcp_request(
        self, server_name: str, message: dict[str, Any]
    ) -> dict[str, Any]:
        """Handle an MCP request for an SDK server.

        This acts as a bridge between JSONRPC messages from the CLI
        and the in-process MCP server. Ideally the MCP SDK would provide
        a method to handle raw JSONRPC, but for now we route manually.

        Args:
            server_name: Name of the SDK MCP server
            message: The JSONRPC message

        Returns:
            The JSONRPC response message. Failures are reported as JSONRPC
            errors (-32601 unknown server/method, -32603 handler exception)
            rather than raised.
        """
        if server_name not in self.sdk_mcp_servers:
            return {
                "jsonrpc": "2.0",
                "id": message.get("id"),
                "error": {
                    "code": -32601,
                    "message": f"Server '{server_name}' not found",
                },
            }

        server = self.sdk_mcp_servers[server_name]
        method = message.get("method")
        # JSON-RPC allows "params" to be absent or null.
        params = message.get("params") or {}

        try:
            # TODO: Python MCP SDK lacks the Transport abstraction that TypeScript has.
            # TypeScript: server.connect(transport) allows custom transports
            # Python: server.run(read_stream, write_stream) requires actual streams
            #
            # This forces us to manually route methods. When Python MCP adds Transport
            # support, we can refactor to match the TypeScript approach.
            if method == "initialize":
                # Handle MCP initialization - hardcoded for tools only, no listChanged
                return {
                    "jsonrpc": "2.0",
                    "id": message.get("id"),
                    "result": {
                        "protocolVersion": "2024-11-05",
                        "capabilities": {
                            "tools": {}  # Tools capability without listChanged
                        },
                        "serverInfo": {
                            "name": server.name,
                            "version": server.version or "1.0.0",
                        },
                    },
                }

            elif method == "tools/list":
                request = ListToolsRequest(method=method)
                handler = server.request_handlers.get(ListToolsRequest)
                if handler:
                    result = await handler(request)
                    # Convert MCP result to JSONRPC response
                    tools_data = [
                        {
                            "name": tool.name,
                            "description": tool.description,
                            "inputSchema": (
                                tool.inputSchema.model_dump()
                                if hasattr(tool.inputSchema, "model_dump")
                                else tool.inputSchema
                            )
                            if tool.inputSchema
                            else {},
                        }
                        for tool in result.root.tools  # type: ignore[union-attr]
                    ]
                    return {
                        "jsonrpc": "2.0",
                        "id": message.get("id"),
                        "result": {"tools": tools_data},
                    }

            elif method == "tools/call":
                call_request = CallToolRequest(
                    method=method,
                    params=CallToolRequestParams(
                        name=params.get("name"), arguments=params.get("arguments", {})
                    ),
                )
                handler = server.request_handlers.get(CallToolRequest)
                if handler:
                    result = await handler(call_request)
                    # Convert MCP result to JSONRPC response
                    content: list[dict[str, Any]] = []
                    for item in result.root.content:  # type: ignore[union-attr]
                        if hasattr(item, "text"):
                            content.append({"type": "text", "text": item.text})
                        elif hasattr(item, "data") and hasattr(item, "mimeType"):
                            # Binary content (image or audio): keep the item's
                            # own type instead of assuming "image".
                            content.append(
                                {
                                    "type": getattr(item, "type", "image"),
                                    "data": item.data,
                                    "mimeType": item.mimeType,
                                }
                            )

                    response_data: dict[str, Any] = {"content": content}
                    # MCP's CallToolResult exposes the error flag as "isError"
                    # (camelCase), both on the pydantic model and on the wire.
                    if getattr(result.root, "isError", False):
                        response_data["isError"] = True

                    return {
                        "jsonrpc": "2.0",
                        "id": message.get("id"),
                        "result": response_data,
                    }

            elif method == "notifications/initialized":
                # Handle initialized notification - just acknowledge it
                return {"jsonrpc": "2.0", "result": {}}

            # Other MCP methods (resources, prompts, ...) are not routed yet; each
            # must be added here manually until the Python MCP SDK supports
            # custom transports.
            return {
                "jsonrpc": "2.0",
                "id": message.get("id"),
                "error": {"code": -32601, "message": f"Method '{method}' not found"},
            }

        except Exception as e:
            return {
                "jsonrpc": "2.0",
                "id": message.get("id"),
                "error": {"code": -32603, "message": str(e)},
            }

    async def interrupt(self) -> None:
        """Send interrupt control request."""
        await self._send_control_request({"subtype": "interrupt"})

    async def set_permission_mode(self, mode: str) -> None:
        """Change permission mode."""
        await self._send_control_request(
            {
                "subtype": "set_permission_mode",
                "mode": mode,
            }
        )

    async def set_model(self, model: str | None) -> None:
        """Change the AI model."""
        await self._send_control_request(
            {
                "subtype": "set_model",
                "model": model,
            }
        )

    async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
        """Stream input messages to transport, then end the transport's input.

        Errors are deliberately logged at debug level and not re-raised.
        """
        try:
            async for message in stream:
                if self._closed:
                    break
                await self.transport.write(json.dumps(message) + "\n")
            # After all messages sent, end input
            await self.transport.end_input()
        except Exception as e:
            logger.debug(f"Error streaming input: {e}")

    async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
        """Receive SDK messages (not control messages).

        Stops at the reader's {"type": "end"} marker and raises Exception for
        an {"type": "error"} marker.
        """
        async for message in self._message_receive:
            # Check for special messages
            if message.get("type") == "end":
                break
            elif message.get("type") == "error":
                raise Exception(message.get("error", "Unknown error"))

            yield message

    async def close(self) -> None:
        """Close the query and transport.

        Safe to call more than once: the task group is exited only once.
        """
        self._closed = True
        if self._tg:
            tg = self._tg
            tg.cancel_scope.cancel()
            # Wait for task group to complete cancellation
            with suppress(anyio.get_cancelled_exc_class()):
                await tg.__aexit__(None, None, None)
            # Clear so a repeated close() does not exit the task group twice.
            self._tg = None
        await self.transport.close()

    # Make Query an async iterator
    def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
        """Return async iterator for messages."""
        return self.receive_messages()

    async def __anext__(self) -> dict[str, Any]:
        """Get next message."""
        async for message in self.receive_messages():
            return message
        raise StopAsyncIteration
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Transport implementations for Claude SDK."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from collections.abc import AsyncIterator
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Transport(ABC):
    """Abstract transport for Claude communication.

    WARNING: This internal API is exposed for custom transport implementations
    (e.g., remote Claude Code connections). The Claude Code team may change or
    remove this abstract class in any future release. Custom implementations
    must be updated to match interface changes.

    This is a low-level transport interface that handles raw I/O with the Claude
    process or service. The Query class builds on top of this to implement the
    control protocol and message routing.
    """

    @abstractmethod
    async def connect(self) -> None:
        """Connect the transport and prepare for communication.

        For subprocess transports, this starts the process.
        For network transports, this establishes the connection.
        """
        pass

    @abstractmethod
    async def write(self, data: str) -> None:
        """Write raw data to the transport.

        Args:
            data: Raw string data to write (typically JSON + newline)
        """
        pass

    @abstractmethod
    def read_messages(self) -> AsyncIterator[dict[str, Any]]:
        """Read and parse messages from the transport.

        Declared as a plain (non-async) method returning an async iterator, so
        implementations are typically written as async generators.

        Yields:
            Parsed JSON messages from the transport
        """
        pass

    @abstractmethod
    async def close(self) -> None:
        """Close the transport connection and clean up resources."""
        pass

    @abstractmethod
    def is_ready(self) -> bool:
        """Check if transport is ready for communication.

        Returns:
            True if transport is ready to send/receive messages
        """
        pass

    @abstractmethod
    async def end_input(self) -> None:
        """End the input stream (close stdin for process transports)."""
        pass


__all__ = ["Transport"]
|