claude-agent-sdk 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of claude-agent-sdk might be problematic. Click here for more details.

@@ -0,0 +1,523 @@
1
+ """Query class for handling bidirectional control protocol."""
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
7
+ from contextlib import suppress
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ import anyio
11
+ from mcp.types import (
12
+ CallToolRequest,
13
+ CallToolRequestParams,
14
+ ListToolsRequest,
15
+ )
16
+
17
+ from ..types import (
18
+ PermissionResultAllow,
19
+ PermissionResultDeny,
20
+ SDKControlPermissionRequest,
21
+ SDKControlRequest,
22
+ SDKControlResponse,
23
+ SDKHookCallbackRequest,
24
+ ToolPermissionContext,
25
+ )
26
+ from .transport import Transport
27
+
28
+ if TYPE_CHECKING:
29
+ from mcp.server import Server as McpServer
30
+
31
# Module-level logger named after this module, so its output can be
# configured through the standard ``logging`` hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)
32
+
33
+
34
class Query:
    """Handles bidirectional control protocol on top of Transport.

    This class manages:
    - Control request/response routing
    - Hook callbacks
    - Tool permission callbacks
    - Message streaming
    - Initialization handshake

    Lifecycle: call :meth:`start` to launch the background reader, optionally
    :meth:`initialize` (streaming mode only), consume messages via
    :meth:`receive_messages` / ``async for``, and finally :meth:`close`.
    """

    def __init__(
        self,
        transport: Transport,
        is_streaming_mode: bool,
        can_use_tool: Callable[
            [str, dict[str, Any], ToolPermissionContext],
            Awaitable[PermissionResultAllow | PermissionResultDeny],
        ]
        | None = None,
        hooks: dict[str, list[dict[str, Any]]] | None = None,
        sdk_mcp_servers: dict[str, "McpServer"] | None = None,
    ):
        """Initialize Query with transport and callbacks.

        Args:
            transport: Low-level transport for I/O
            is_streaming_mode: Whether using streaming (bidirectional) mode
            can_use_tool: Optional callback for tool permission requests
            hooks: Optional hook configurations, keyed by event name; each
                matcher dict may carry a "matcher" value and a "hooks" list of
                callbacks (see :meth:`initialize`)
            sdk_mcp_servers: Optional SDK MCP server instances, keyed by name
        """
        self.transport = transport
        self.is_streaming_mode = is_streaming_mode
        self.can_use_tool = can_use_tool
        self.hooks = hooks or {}
        self.sdk_mcp_servers = sdk_mcp_servers or {}

        # Control protocol state: one Event per in-flight outgoing request,
        # plus the response (or Exception) the reader stores before setting it.
        self.pending_control_responses: dict[str, anyio.Event] = {}
        self.pending_control_results: dict[str, dict[str, Any] | Exception] = {}
        # Hook callbacks registered during initialize(), keyed by "hook_<n>".
        self.hook_callbacks: dict[str, Callable[..., Any]] = {}
        self.next_callback_id = 0
        self._request_counter = 0

        # Message stream carrying non-control messages to receive_messages().
        self._message_send, self._message_receive = anyio.create_memory_object_stream[
            dict[str, Any]
        ](max_buffer_size=100)
        self._tg: anyio.abc.TaskGroup | None = None
        self._initialized = False
        self._closed = False
        self._initialization_result: dict[str, Any] | None = None

    async def initialize(self) -> dict[str, Any] | None:
        """Initialize control protocol if in streaming mode.

        Registers every hook callback under a generated ID and sends an
        ``initialize`` control request that advertises those IDs per event.

        Returns:
            Initialize response with supported commands, or None if not streaming
        """
        if not self.is_streaming_mode:
            return None

        # Build hooks configuration for initialization
        hooks_config: dict[str, Any] = {}
        if self.hooks:
            for event, matchers in self.hooks.items():
                if matchers:
                    hooks_config[event] = []
                    for matcher in matchers:
                        callback_ids = []
                        for callback in matcher.get("hooks", []):
                            callback_id = f"hook_{self.next_callback_id}"
                            self.next_callback_id += 1
                            self.hook_callbacks[callback_id] = callback
                            callback_ids.append(callback_id)
                        hooks_config[event].append(
                            {
                                "matcher": matcher.get("matcher"),
                                "hookCallbackIds": callback_ids,
                            }
                        )

        # Send initialize request
        request = {
            "subtype": "initialize",
            "hooks": hooks_config if hooks_config else None,
        }

        response = await self._send_control_request(request)
        self._initialized = True
        self._initialization_result = response  # Store for later access
        return response

    async def start(self) -> None:
        """Start reading messages from transport.

        Does nothing if the reader is already running or the query has been
        closed.
        """
        if self._tg is None and not self._closed:
            self._tg = anyio.create_task_group()
            await self._tg.__aenter__()
            self._tg.start_soon(self._read_messages)

    async def _read_messages(self) -> None:
        """Read messages from transport and route them.

        ``control_response`` messages resolve the matching pending request made
        by :meth:`_send_control_request`. ``control_request`` messages are
        handled concurrently in the task group. ``control_cancel_request`` is
        currently ignored. Every other message is forwarded to the internal
        stream consumed by :meth:`receive_messages`.
        """
        try:
            async for message in self.transport.read_messages():
                if self._closed:
                    break

                msg_type = message.get("type")

                # Route control messages
                if msg_type == "control_response":
                    response = message.get("response", {})
                    request_id = response.get("request_id")
                    # Responses for unknown or timed-out requests are dropped.
                    if request_id in self.pending_control_responses:
                        event = self.pending_control_responses[request_id]
                        if response.get("subtype") == "error":
                            self.pending_control_results[request_id] = Exception(
                                response.get("error", "Unknown error")
                            )
                        else:
                            self.pending_control_results[request_id] = response
                        event.set()
                    continue

                elif msg_type == "control_request":
                    # Handle incoming control requests from CLI
                    # Cast message to SDKControlRequest for type safety
                    request: SDKControlRequest = message  # type: ignore[assignment]
                    if self._tg:
                        self._tg.start_soon(self._handle_control_request, request)
                    continue

                elif msg_type == "control_cancel_request":
                    # Handle cancel requests
                    # TODO: Implement cancellation support
                    continue

                # Regular SDK messages go to the stream
                await self._message_send.send(message)

        except anyio.get_cancelled_exc_class():
            # Task was cancelled - this is expected behavior
            logger.debug("Read task cancelled")
            raise  # Re-raise to properly handle cancellation
        except Exception as e:
            logger.error(f"Fatal error in message reader: {e}")
            # Put error in stream so iterators can handle it
            await self._message_send.send({"type": "error", "error": str(e)})
        finally:
            # Always signal end of stream. Use send_nowait: when this task is
            # being cancelled (e.g. by close()), an awaiting send() would be
            # cancelled at its checkpoint and the marker would never arrive.
            with suppress(
                anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError
            ):
                self._message_send.send_nowait({"type": "end"})
            # Closing the send side lets consumers drain what is buffered and
            # then stop, instead of waiting forever for another message.
            self._message_send.close()

    async def _handle_control_request(self, request: SDKControlRequest) -> None:
        """Handle incoming control request from CLI.

        Dispatches on ``request["request"]["subtype"]`` (``can_use_tool``,
        ``hook_callback`` or ``mcp_message``) and always writes back a
        ``control_response``. The response is ``success`` with the handler's
        data, or ``error`` with the exception text.
        """
        request_id = request["request_id"]
        request_data = request["request"]
        subtype = request_data["subtype"]

        try:
            response_data: dict[str, Any] = {}

            if subtype == "can_use_tool":
                permission_request: SDKControlPermissionRequest = request_data  # type: ignore[assignment]
                # Handle tool permission request
                if not self.can_use_tool:
                    raise Exception("canUseTool callback is not provided")

                context = ToolPermissionContext(
                    signal=None,  # TODO: Add abort signal support
                    suggestions=permission_request.get("permission_suggestions", [])
                    or [],
                )

                response = await self.can_use_tool(
                    permission_request["tool_name"],
                    permission_request["input"],
                    context,
                )

                # Convert PermissionResult to expected dict format
                if isinstance(response, PermissionResultAllow):
                    response_data = {"allow": True}
                    if response.updated_input is not None:
                        response_data["input"] = response.updated_input
                    # TODO: Handle updatedPermissions when control protocol supports it
                elif isinstance(response, PermissionResultDeny):
                    response_data = {"allow": False, "reason": response.message}
                    # TODO: Handle interrupt flag when control protocol supports it
                else:
                    raise TypeError(
                        f"Tool permission callback must return PermissionResult (PermissionResultAllow or PermissionResultDeny), got {type(response)}"
                    )

            elif subtype == "hook_callback":
                hook_callback_request: SDKHookCallbackRequest = request_data  # type: ignore[assignment]
                # Handle hook callback
                callback_id = hook_callback_request["callback_id"]
                callback = self.hook_callbacks.get(callback_id)
                if not callback:
                    raise Exception(f"No hook callback found for ID: {callback_id}")

                response_data = await callback(
                    request_data.get("input"),
                    request_data.get("tool_use_id"),
                    {"signal": None},  # TODO: Add abort signal support
                )

            elif subtype == "mcp_message":
                # Handle SDK MCP request
                server_name = request_data.get("server_name")
                mcp_message = request_data.get("message")

                if not server_name or not mcp_message:
                    raise Exception("Missing server_name or message for MCP request")

                # Type narrowing - we've verified these are not None above
                assert isinstance(server_name, str)
                assert isinstance(mcp_message, dict)
                mcp_response = await self._handle_sdk_mcp_request(
                    server_name, mcp_message
                )
                # Wrap the MCP response as expected by the control protocol
                response_data = {"mcp_response": mcp_response}

            else:
                raise Exception(f"Unsupported control request subtype: {subtype}")

            # Send success response
            success_response: SDKControlResponse = {
                "type": "control_response",
                "response": {
                    "subtype": "success",
                    "request_id": request_id,
                    "response": response_data,
                },
            }
            await self.transport.write(json.dumps(success_response) + "\n")

        except Exception as e:
            # Send error response
            error_response: SDKControlResponse = {
                "type": "control_response",
                "response": {
                    "subtype": "error",
                    "request_id": request_id,
                    "error": str(e),
                },
            }
            await self.transport.write(json.dumps(error_response) + "\n")

    async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]:
        """Send control request to CLI and wait for response.

        Args:
            request: Control request payload (must include ``subtype``).

        Returns:
            The ``response`` field of the CLI's success reply, or ``{}`` if it
            is missing or not a dict.

        Raises:
            Exception: If not in streaming mode, if no reply arrives within
                60 seconds, or with the CLI's error text on an error reply.
        """
        if not self.is_streaming_mode:
            raise Exception("Control requests require streaming mode")

        # Generate unique request ID
        self._request_counter += 1
        request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}"

        # Register the event before writing so the reader cannot miss a
        # response that arrives immediately.
        event = anyio.Event()
        self.pending_control_responses[request_id] = event

        # Build and send request
        control_request = {
            "type": "control_request",
            "request_id": request_id,
            "request": request,
        }

        try:
            await self.transport.write(json.dumps(control_request) + "\n")

            # Wait for response
            try:
                with anyio.fail_after(60.0):
                    await event.wait()
            except TimeoutError as e:
                raise Exception(
                    f"Control request timeout: {request.get('subtype')}"
                ) from e

            result = self.pending_control_results.pop(request_id)

            if isinstance(result, Exception):
                raise result

            response_data = result.get("response", {})
            return response_data if isinstance(response_data, dict) else {}
        finally:
            # Always drop bookkeeping, including when write() fails or the
            # caller is cancelled, so entries never leak. A late response for
            # this ID is then ignored by the reader.
            self.pending_control_responses.pop(request_id, None)
            self.pending_control_results.pop(request_id, None)

    async def _handle_sdk_mcp_request(
        self, server_name: str, message: dict[str, Any]
    ) -> dict[str, Any]:
        """Handle an MCP request for an SDK server.

        This acts as a bridge between JSONRPC messages from the CLI
        and the in-process MCP server. Ideally the MCP SDK would provide
        a method to handle raw JSONRPC, but for now we route manually.

        Args:
            server_name: Name of the SDK MCP server
            message: The JSONRPC message

        Returns:
            The response message. Failures are reported as JSONRPC error
            objects (-32601 unknown server/method, -32603 internal error)
            rather than raised.
        """
        if server_name not in self.sdk_mcp_servers:
            return {
                "jsonrpc": "2.0",
                "id": message.get("id"),
                "error": {
                    "code": -32601,
                    "message": f"Server '{server_name}' not found",
                },
            }

        server = self.sdk_mcp_servers[server_name]
        method = message.get("method")
        # "params" may be absent or an explicit null in JSONRPC.
        params = message.get("params") or {}

        try:
            # TODO: Python MCP SDK lacks the Transport abstraction that TypeScript has.
            # TypeScript: server.connect(transport) allows custom transports
            # Python: server.run(read_stream, write_stream) requires actual streams
            #
            # This forces us to manually route methods. When Python MCP adds Transport
            # support, we can refactor to match the TypeScript approach.
            if method == "initialize":
                # Handle MCP initialization - hardcoded for tools only, no listChanged
                return {
                    "jsonrpc": "2.0",
                    "id": message.get("id"),
                    "result": {
                        "protocolVersion": "2024-11-05",
                        "capabilities": {
                            "tools": {}  # Tools capability without listChanged
                        },
                        "serverInfo": {
                            "name": server.name,
                            "version": server.version or "1.0.0",
                        },
                    },
                }

            elif method == "tools/list":
                request = ListToolsRequest(method=method)
                handler = server.request_handlers.get(ListToolsRequest)
                if handler:
                    result = await handler(request)
                    # Convert MCP result to JSONRPC response
                    tools_data = [
                        {
                            "name": tool.name,
                            "description": tool.description,
                            "inputSchema": (
                                tool.inputSchema.model_dump()
                                if hasattr(tool.inputSchema, "model_dump")
                                else tool.inputSchema
                            )
                            if tool.inputSchema
                            else {},
                        }
                        for tool in result.root.tools  # type: ignore[union-attr]
                    ]
                    return {
                        "jsonrpc": "2.0",
                        "id": message.get("id"),
                        "result": {"tools": tools_data},
                    }

            elif method == "tools/call":
                call_request = CallToolRequest(
                    method=method,
                    params=CallToolRequestParams(
                        name=params.get("name"), arguments=params.get("arguments", {})
                    ),
                )
                handler = server.request_handlers.get(CallToolRequest)
                if handler:
                    result = await handler(call_request)
                    # Convert MCP result to JSONRPC response
                    content = []
                    for item in result.root.content:  # type: ignore[union-attr]
                        if hasattr(item, "text"):
                            content.append({"type": "text", "text": item.text})
                        elif hasattr(item, "data") and hasattr(item, "mimeType"):
                            content.append(
                                {
                                    "type": "image",
                                    "data": item.data,
                                    "mimeType": item.mimeType,
                                }
                            )

                    response_data: dict[str, Any] = {"content": content}
                    # MCP's CallToolResult exposes the error flag as `isError`
                    # (the previous `is_error` lookup never matched), and the
                    # MCP wire format uses the same key.
                    if getattr(result.root, "isError", False):
                        response_data["isError"] = True

                    return {
                        "jsonrpc": "2.0",
                        "id": message.get("id"),
                        "result": response_data,
                    }

            elif method == "notifications/initialized":
                # Handle initialized notification - just acknowledge it
                return {"jsonrpc": "2.0", "result": {}}

            # Add more methods here as MCP SDK adds them (resources, prompts, etc.)
            # Until then, every supported method must be routed by hand above.

            return {
                "jsonrpc": "2.0",
                "id": message.get("id"),
                "error": {"code": -32601, "message": f"Method '{method}' not found"},
            }

        except Exception as e:
            return {
                "jsonrpc": "2.0",
                "id": message.get("id"),
                "error": {"code": -32603, "message": str(e)},
            }

    async def interrupt(self) -> None:
        """Send interrupt control request."""
        await self._send_control_request({"subtype": "interrupt"})

    async def set_permission_mode(self, mode: str) -> None:
        """Change permission mode."""
        await self._send_control_request(
            {
                "subtype": "set_permission_mode",
                "mode": mode,
            }
        )

    async def set_model(self, model: str | None) -> None:
        """Change the AI model."""
        await self._send_control_request(
            {
                "subtype": "set_model",
                "model": model,
            }
        )

    async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
        """Stream input messages to transport, then end the input.

        Errors are logged at debug level and not re-raised (best effort).
        """
        try:
            async for message in stream:
                if self._closed:
                    break
                await self.transport.write(json.dumps(message) + "\n")
            # After all messages sent, end input
            await self.transport.end_input()
        except Exception as e:
            logger.debug(f"Error streaming input: {e}")

    async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
        """Receive SDK messages (not control messages).

        Stops at the reader's end marker. Raises ``Exception`` with the
        reader's error text if the reader failed.
        """
        async for message in self._message_receive:
            # Check for special messages
            if message.get("type") == "end":
                break
            elif message.get("type") == "error":
                raise Exception(message.get("error", "Unknown error"))

            yield message

    async def close(self) -> None:
        """Close the query and transport.

        The task group is torn down only once, so calling this again does not
        re-exit an already exited task group.
        """
        self._closed = True
        tg, self._tg = self._tg, None
        if tg:
            tg.cancel_scope.cancel()
            # Wait for task group to complete cancellation
            with suppress(anyio.get_cancelled_exc_class()):
                await tg.__aexit__(None, None, None)
        await self.transport.close()

    # Make Query an async iterator
    def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
        """Return async iterator for messages."""
        return self.receive_messages()

    async def __anext__(self) -> dict[str, Any]:
        """Get next message."""
        async for message in self.receive_messages():
            return message
        raise StopAsyncIteration
@@ -0,0 +1,68 @@
1
+ """Transport implementations for Claude SDK."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import AsyncIterator
5
+ from typing import Any
6
+
7
+
8
class Transport(ABC):
    """Abstract transport for Claude communication.

    WARNING: This internal API is exposed for custom transport implementations
    (e.g., remote Claude Code connections). The Claude Code team may change or
    remove this abstract class in any future release. Custom implementations
    must be updated to match interface changes.

    This is a low-level transport interface that handles raw I/O with the Claude
    process or service. The Query class builds on top of this to implement the
    control protocol and message routing.

    Subclasses must implement every abstract method below; instantiating a
    subclass that leaves any of them unimplemented raises ``TypeError``.
    """

    @abstractmethod
    async def connect(self) -> None:
        """Connect the transport and prepare for communication.

        For subprocess transports, this starts the process.
        For network transports, this establishes the connection.
        """

    @abstractmethod
    async def write(self, data: str) -> None:
        """Write raw data to the transport.

        Args:
            data: Raw string data to write (typically JSON + newline)
        """

    @abstractmethod
    def read_messages(self) -> AsyncIterator[dict[str, Any]]:
        """Read and parse messages from the transport.

        This is a plain (non-``async``) method that must return an async
        iterator directly, because callers iterate it with ``async for``
        without awaiting it first. An ``async def`` generator satisfies this.

        Yields:
            Parsed JSON messages from the transport
        """

    @abstractmethod
    async def close(self) -> None:
        """Close the transport connection and clean up resources."""

    @abstractmethod
    def is_ready(self) -> bool:
        """Check if transport is ready for communication.

        Returns:
            True if transport is ready to send/receive messages
        """

    @abstractmethod
    async def end_input(self) -> None:
        """End the input stream (close stdin for process transports)."""
66
+
67
+
68
# Public surface of this module: only the abstract Transport base class.
__all__: list[str] = ["Transport"]