ccproxy-api 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. ccproxy/_version.py +2 -2
  2. ccproxy/adapters/codex/__init__.py +11 -0
  3. ccproxy/adapters/openai/adapter.py +1 -1
  4. ccproxy/adapters/openai/models.py +1 -1
  5. ccproxy/adapters/openai/response_adapter.py +355 -0
  6. ccproxy/adapters/openai/response_models.py +178 -0
  7. ccproxy/adapters/openai/streaming.py +1 -0
  8. ccproxy/api/app.py +150 -224
  9. ccproxy/api/dependencies.py +22 -2
  10. ccproxy/api/middleware/errors.py +27 -3
  11. ccproxy/api/middleware/logging.py +4 -0
  12. ccproxy/api/responses.py +6 -1
  13. ccproxy/api/routes/claude.py +222 -17
  14. ccproxy/api/routes/codex.py +1231 -0
  15. ccproxy/api/routes/health.py +228 -3
  16. ccproxy/api/routes/proxy.py +25 -6
  17. ccproxy/api/services/permission_service.py +2 -2
  18. ccproxy/auth/openai/__init__.py +13 -0
  19. ccproxy/auth/openai/credentials.py +166 -0
  20. ccproxy/auth/openai/oauth_client.py +334 -0
  21. ccproxy/auth/openai/storage.py +184 -0
  22. ccproxy/claude_sdk/__init__.py +4 -8
  23. ccproxy/claude_sdk/client.py +661 -131
  24. ccproxy/claude_sdk/exceptions.py +16 -0
  25. ccproxy/claude_sdk/manager.py +219 -0
  26. ccproxy/claude_sdk/message_queue.py +342 -0
  27. ccproxy/claude_sdk/options.py +6 -1
  28. ccproxy/claude_sdk/session_client.py +546 -0
  29. ccproxy/claude_sdk/session_pool.py +550 -0
  30. ccproxy/claude_sdk/stream_handle.py +538 -0
  31. ccproxy/claude_sdk/stream_worker.py +392 -0
  32. ccproxy/claude_sdk/streaming.py +53 -11
  33. ccproxy/cli/commands/auth.py +398 -1
  34. ccproxy/cli/commands/serve.py +99 -1
  35. ccproxy/cli/options/claude_options.py +47 -0
  36. ccproxy/config/__init__.py +0 -3
  37. ccproxy/config/claude.py +171 -23
  38. ccproxy/config/codex.py +100 -0
  39. ccproxy/config/discovery.py +10 -1
  40. ccproxy/config/scheduler.py +2 -2
  41. ccproxy/config/settings.py +38 -1
  42. ccproxy/core/codex_transformers.py +389 -0
  43. ccproxy/core/http_transformers.py +458 -75
  44. ccproxy/core/logging.py +108 -12
  45. ccproxy/core/transformers.py +5 -0
  46. ccproxy/models/claude_sdk.py +57 -0
  47. ccproxy/models/detection.py +208 -0
  48. ccproxy/models/requests.py +22 -0
  49. ccproxy/models/responses.py +16 -0
  50. ccproxy/observability/access_logger.py +72 -14
  51. ccproxy/observability/metrics.py +151 -0
  52. ccproxy/observability/storage/duckdb_simple.py +12 -0
  53. ccproxy/observability/storage/models.py +16 -0
  54. ccproxy/observability/streaming_response.py +107 -0
  55. ccproxy/scheduler/manager.py +31 -6
  56. ccproxy/scheduler/tasks.py +122 -0
  57. ccproxy/services/claude_detection_service.py +269 -0
  58. ccproxy/services/claude_sdk_service.py +333 -130
  59. ccproxy/services/codex_detection_service.py +263 -0
  60. ccproxy/services/proxy_service.py +618 -197
  61. ccproxy/utils/__init__.py +9 -1
  62. ccproxy/utils/disconnection_monitor.py +83 -0
  63. ccproxy/utils/id_generator.py +12 -0
  64. ccproxy/utils/model_mapping.py +7 -5
  65. ccproxy/utils/startup_helpers.py +470 -0
  66. ccproxy_api-0.1.6.dist-info/METADATA +615 -0
  67. {ccproxy_api-0.1.4.dist-info → ccproxy_api-0.1.6.dist-info}/RECORD +70 -47
  68. ccproxy/config/loader.py +0 -105
  69. ccproxy_api-0.1.4.dist-info/METADATA +0 -369
  70. {ccproxy_api-0.1.4.dist-info → ccproxy_api-0.1.6.dist-info}/WHEEL +0 -0
  71. {ccproxy_api-0.1.4.dist-info → ccproxy_api-0.1.6.dist-info}/entry_points.txt +0 -0
  72. {ccproxy_api-0.1.4.dist-info → ccproxy_api-0.1.6.dist-info}/licenses/LICENSE +0 -0
@@ -1,14 +1,21 @@
1
1
  """Claude SDK client wrapper for handling core Claude Code SDK interactions."""
2
2
 
3
+ import asyncio
4
+ import contextlib
3
5
  from collections.abc import AsyncIterator
4
- from typing import Any
6
+ from typing import Any, TypeVar, cast
5
7
 
6
8
  import structlog
7
9
  from pydantic import BaseModel
8
10
 
11
+ from ccproxy.claude_sdk.exceptions import ClaudeSDKError, StreamTimeoutError
12
+ from ccproxy.claude_sdk.manager import SessionManager
13
+ from ccproxy.claude_sdk.stream_handle import StreamHandle
14
+ from ccproxy.config.settings import Settings
9
15
  from ccproxy.core.async_utils import patched_typing
10
16
  from ccproxy.core.errors import ClaudeProxyError, ServiceUnavailableError
11
17
  from ccproxy.models import claude_sdk as sdk_models
18
+ from ccproxy.models.claude_sdk import SDKMessage
12
19
  from ccproxy.observability import timed_operation
13
20
 
14
21
 
@@ -22,7 +29,9 @@ with patched_typing():
22
29
  CLIJSONDecodeError,
23
30
  CLINotFoundError,
24
31
  ProcessError,
25
- query,
32
+ )
33
+ from claude_code_sdk import (
34
+ ClaudeSDKClient as ImportedClaudeSDKClient,
26
35
  )
27
36
  from claude_code_sdk import (
28
37
  ResultMessage as SDKResultMessage,
@@ -37,17 +46,7 @@ with patched_typing():
37
46
 
38
47
  logger = structlog.get_logger(__name__)
39
48
 
40
-
41
- class ClaudeSDKError(Exception):
42
- """Base exception for Claude SDK errors."""
43
-
44
-
45
- class ClaudeSDKConnectionError(ClaudeSDKError):
46
- """Raised when unable to connect to Claude Code."""
47
-
48
-
49
- class ClaudeSDKProcessError(ClaudeSDKError):
50
- """Raised when Claude Code process fails."""
49
+ T = TypeVar("T", bound=BaseModel)
51
50
 
52
51
 
53
52
  class ClaudeSDKClient:
@@ -55,161 +54,676 @@ class ClaudeSDKClient:
55
54
  Minimal Claude SDK client wrapper that handles core SDK interactions.
56
55
 
57
56
  This class provides a clean interface to the Claude Code SDK while handling
58
- error translation and basic query execution.
57
+ error translation and basic query execution. Supports both stateless query()
58
+ calls and pooled connection reuse for improved performance.
59
59
  """
60
60
 
61
- def __init__(self) -> None:
62
- """Initialize the Claude SDK client."""
63
- self._last_api_call_time_ms: float = 0.0
61
+ # Class constants
62
+ FIRST_CHUNK_TIMEOUT = 4.0 # Standard timeout for all streaming methods
63
+ MESSAGE_TYPE_MAP: dict[type[Any], type[BaseModel]] = {
64
+ SDKUserMessage: sdk_models.UserMessage,
65
+ SDKAssistantMessage: sdk_models.AssistantMessage,
66
+ SDKSystemMessage: sdk_models.SystemMessage,
67
+ SDKResultMessage: sdk_models.ResultMessage,
68
+ }
69
+
70
+ def __init__(
71
+ self,
72
+ settings: Settings | None = None,
73
+ session_manager: SessionManager | None = None,
74
+ ) -> None:
75
+ """Initialize the Claude SDK client.
64
76
 
65
- async def query_completion(
66
- self, prompt: str, options: ClaudeCodeOptions, request_id: str | None = None
77
+ Args:
78
+ settings: Application settings for session pool configuration
79
+ session_manager: Optional SessionManager instance for dependency injection
80
+ """
81
+ self._last_api_call_time_ms: float = 0.0
82
+ self._settings = settings
83
+ self._session_manager = session_manager
84
+
85
+ @contextlib.asynccontextmanager
86
+ async def _handle_sdk_exceptions(
87
+ self, operation: str, request_id: str | None = None
88
+ ) -> AsyncIterator[None]:
89
+ """Context manager for common SDK error handling."""
90
+ try:
91
+ yield
92
+ except (CLINotFoundError, CLIConnectionError) as e:
93
+ logger.error(
94
+ "claude_sdk_connection_failed",
95
+ error=str(e),
96
+ error_type=type(e).__name__,
97
+ operation=operation,
98
+ request_id=request_id,
99
+ )
100
+ raise ServiceUnavailableError(f"Claude CLI not available: {str(e)}") from e
101
+ except (ProcessError, CLIJSONDecodeError) as e:
102
+ logger.error(
103
+ "claude_sdk_process_failed",
104
+ error=str(e),
105
+ error_type=type(e).__name__,
106
+ operation=operation,
107
+ request_id=request_id,
108
+ )
109
+ raise ClaudeProxyError(
110
+ message=f"Claude process error: {str(e)}",
111
+ error_type="service_unavailable_error",
112
+ status_code=503,
113
+ ) from e
114
+ except StreamTimeoutError:
115
+ # Re-raise StreamTimeoutError for service layer to handle
116
+ raise
117
+ except Exception as e:
118
+ logger.error(
119
+ "claude_sdk_unexpected_error",
120
+ error=str(e),
121
+ error_type=type(e).__name__,
122
+ operation=operation,
123
+ request_id=request_id,
124
+ )
125
+ raise ClaudeProxyError(
126
+ message=f"Unexpected error: {str(e)}",
127
+ error_type="internal_server_error",
128
+ status_code=500,
129
+ ) from e
130
+
131
+ async def _execute_with_client(
132
+ self,
133
+ client: ImportedClaudeSDKClient, # Claude SDK client (ImportedClaudeSDKClient)
134
+ message: SDKMessage,
135
+ session_id: str | None,
136
+ request_id: str | None,
137
+ session_client: Any = None, # SessionClient for session pool
67
138
  ) -> AsyncIterator[
68
139
  sdk_models.UserMessage
69
140
  | sdk_models.AssistantMessage
70
141
  | sdk_models.SystemMessage
71
142
  | sdk_models.ResultMessage
72
143
  ]:
144
+ """Execute query with standard 4-second first chunk timeout."""
145
+ # Send message
146
+ message_dict = message.model_dump()
147
+ logger.debug("sending_sdk_message", message=message_dict)
148
+
149
+ async def message_iter() -> AsyncIterator[dict[str, Any]]:
150
+ yield message_dict
151
+
152
+ if session_id:
153
+ await client.query(message_iter(), session_id=session_id)
154
+ else:
155
+ await client.query(message_iter())
156
+
157
+ # Get response with 4s timeout on first chunk
158
+ response_iterator = client.receive_response()
159
+ first_message, remaining_iterator = await self._wait_for_first_chunk(
160
+ response_iterator,
161
+ self.FIRST_CHUNK_TIMEOUT, # 4 seconds for all methods
162
+ session_id,
163
+ request_id,
164
+ )
165
+
166
+ # Chain first message with remaining
167
+ async def message_chain() -> AsyncIterator[Any]:
168
+ yield first_message
169
+ async for msg in remaining_iterator:
170
+ yield msg
171
+
172
+ # Process messages
173
+ async for converted_message in self._process_message_stream(
174
+ message_chain(), request_id, session_id, session_client
175
+ ):
176
+ yield converted_message
177
+
178
+ def _convert_anthropic_messages_to_sdk(
179
+ self, messages: list[dict[str, Any]]
180
+ ) -> list[sdk_models.UserMessage]:
181
+ """Convert Anthropic API messages to Claude SDK UserMessage format.
182
+
183
+ Args:
184
+ messages: List of Anthropic API messages
185
+
186
+ Returns:
187
+ List of Claude SDK UserMessage objects
73
188
  """
74
- Execute a query using the Claude Code SDK and yields strongly-typed Pydantic models.
189
+ sdk_messages = []
190
+
191
+ for msg in messages:
192
+ if msg.get("role") == "user":
193
+ # Convert content to SDK format
194
+ content_blocks: list[sdk_models.ContentBlock] = []
195
+
196
+ if isinstance(msg.get("content"), str):
197
+ # Simple text content
198
+ content_blocks.append(
199
+ sdk_models.TextBlock(type="text", text=msg["content"])
200
+ )
201
+ elif isinstance(msg.get("content"), list):
202
+ # List of content blocks
203
+ for block in msg["content"]:
204
+ if isinstance(block, dict):
205
+ if block.get("type") == "text":
206
+ content_blocks.append(
207
+ sdk_models.TextBlock(
208
+ type="text", text=block.get("text", "")
209
+ )
210
+ )
211
+ elif block.get("type") == "tool_result":
212
+ content_blocks.append(
213
+ sdk_models.ToolResultBlock(
214
+ type="tool_result",
215
+ tool_use_id=block.get("tool_use_id", ""),
216
+ content=block.get("content"),
217
+ is_error=block.get("is_error", False),
218
+ )
219
+ )
220
+ # Add other block types as needed
221
+
222
+ if content_blocks:
223
+ sdk_messages.append(sdk_models.UserMessage(content=content_blocks))
224
+
225
+ return sdk_messages
226
+
227
+ def _should_use_session_pool(self, session_id: str | None) -> bool:
228
+ """Determine if session pool should be used for this request."""
229
+ if not session_id or not self._session_manager:
230
+ return False
231
+
232
+ # Check settings using safe attribute chaining
233
+ if not self._settings:
234
+ return False
235
+
236
+ claude_settings = getattr(self._settings, "claude", None)
237
+ if not claude_settings:
238
+ return False
239
+
240
+ pool_settings = getattr(claude_settings, "sdk_session_pool", None)
241
+ if not pool_settings:
242
+ return False
243
+
244
+ return bool(getattr(pool_settings, "enabled", False))
245
+
246
+ async def query_completion(
247
+ self,
248
+ message: SDKMessage,
249
+ options: ClaudeCodeOptions,
250
+ request_id: str | None = None,
251
+ session_id: str | None = None,
252
+ ) -> StreamHandle:
253
+ """
254
+ Execute a query using the Claude Code SDK and return a StreamHandle.
75
255
 
76
256
  Args:
77
- prompt: The prompt string to send to Claude
257
+ message: SDKMessage to send to Claude SDK
78
258
  options: Claude Code options configuration
79
259
  request_id: Optional request ID for correlation
260
+ session_id: Optional session ID for conversation continuity
80
261
 
81
- Yields:
82
- Strongly-typed Pydantic messages from ccproxy.claude_sdk.models
262
+ Returns:
263
+ StreamHandle that can create listeners for the stream
83
264
 
84
265
  Raises:
85
266
  ClaudeSDKError: If the query fails
86
267
  """
87
- async with timed_operation("claude_sdk_query", request_id) as op:
268
+ # Determine routing strategy
269
+ if self._should_use_session_pool(session_id):
270
+ return await self._create_session_pool_stream_handle(
271
+ message, options, request_id, session_id
272
+ )
273
+ else:
274
+ return await self._create_direct_stream_handle(
275
+ message, options, request_id, session_id
276
+ )
277
+
278
+ async def _create_direct_stream_handle(
279
+ self,
280
+ message: SDKMessage,
281
+ options: ClaudeCodeOptions,
282
+ request_id: str | None = None,
283
+ session_id: str | None = None,
284
+ ) -> StreamHandle:
285
+ """Create stream handle for direct query (no session pool)."""
286
+ message_iterator = self._query(message, options, request_id, session_id)
287
+
288
+ return StreamHandle(
289
+ message_iterator=message_iterator,
290
+ session_id=session_id,
291
+ request_id=request_id,
292
+ session_client=None,
293
+ session_config=self._settings.claude.sdk_session_pool
294
+ if self._settings
295
+ else None, # StreamHandle will use defaults
296
+ )
297
+
298
+ async def _create_session_pool_stream_handle(
299
+ self,
300
+ message: SDKMessage,
301
+ options: ClaudeCodeOptions,
302
+ request_id: str | None = None,
303
+ session_id: str | None = None,
304
+ ) -> StreamHandle:
305
+ """Create stream handle for session pool query."""
306
+ if not session_id:
307
+ raise ClaudeSDKError("Session ID required for session pool")
308
+ if not self._session_manager:
309
+ raise ClaudeSDKError("No session manager available")
310
+
311
+ # Enable continue conversation for session pool
312
+ options.continue_conversation = True
313
+ session_client = await self._session_manager.get_session_client(
314
+ session_id, options
315
+ )
316
+
317
+ message_iterator = self._query_with_session_pool(
318
+ message, options, request_id, session_id
319
+ )
320
+
321
+ # Get session config from session manager
322
+ session_config = None
323
+ if (
324
+ self._session_manager
325
+ and hasattr(self._session_manager, "_session_pool")
326
+ and self._session_manager._session_pool
327
+ ):
328
+ session_config = self._session_manager._session_pool.config
329
+
330
+ stream_handle = StreamHandle(
331
+ message_iterator=message_iterator,
332
+ session_id=session_id,
333
+ request_id=request_id,
334
+ session_client=session_client,
335
+ session_config=session_config,
336
+ )
337
+
338
+ # Set the active stream handle on the session client for proper cleanup
339
+ session_client.active_stream_handle = stream_handle
340
+
341
+ return stream_handle
342
+
343
+ async def _query(
344
+ self,
345
+ message: SDKMessage,
346
+ options: ClaudeCodeOptions,
347
+ request_id: str | None = None,
348
+ session_id: str | None = None,
349
+ ) -> AsyncIterator[
350
+ sdk_models.UserMessage
351
+ | sdk_models.AssistantMessage
352
+ | sdk_models.SystemMessage
353
+ | sdk_models.ResultMessage
354
+ ]:
355
+ """Execute query using direct connection (no pool)."""
356
+ async with (
357
+ timed_operation("claude_sdk_query_direct", request_id) as op,
358
+ self._handle_sdk_exceptions("direct_query", request_id),
359
+ ):
360
+ client = ImportedClaudeSDKClient(options)
88
361
  try:
89
- logger.debug("claude_sdk_query_start", prompt_length=len(prompt))
362
+ await client.connect()
90
363
 
91
364
  message_count = 0
92
- async for message in query(prompt=prompt, options=options):
365
+ async for msg in self._execute_with_client(
366
+ client, message, session_id, request_id
367
+ ):
93
368
  message_count += 1
369
+ yield msg
94
370
 
95
- logger.debug(
96
- "claude_sdk_raw_message_received",
97
- message_type=type(message).__name__,
98
- message_count=message_count,
371
+ op["message_count"] = message_count
372
+ self._last_api_call_time_ms = op.get("duration_ms", 0.0)
373
+
374
+ finally:
375
+ # Critical: Always disconnect non-session clients to prevent reuse
376
+ try:
377
+ await client.disconnect()
378
+ except Exception as e:
379
+ logger.warning(
380
+ "claude_sdk_disconnect_failed",
381
+ error=str(e),
99
382
  request_id=request_id,
100
- has_content=hasattr(message, "content")
101
- and bool(getattr(message, "content", None)),
102
- content_preview=str(message)[:150],
103
383
  )
104
384
 
105
- model_class: type[BaseModel] | None = None
106
- if isinstance(message, SDKUserMessage):
107
- model_class = sdk_models.UserMessage
108
- elif isinstance(message, SDKAssistantMessage):
109
- model_class = sdk_models.AssistantMessage
110
- elif isinstance(message, SDKSystemMessage):
111
- model_class = sdk_models.SystemMessage
112
- elif isinstance(message, SDKResultMessage):
113
- model_class = sdk_models.ResultMessage
114
-
115
- # Convert Claude SDK message to our Pydantic model
116
- try:
117
- if hasattr(message, "__dict__"):
118
- converted_message = model_class.model_validate(
119
- vars(message)
120
- )
121
- else:
122
- # For dataclass objects, use dataclass.asdict equivalent
123
- message_dict = {}
124
- if hasattr(message, "__dataclass_fields__"):
125
- message_dict = {
126
- field: getattr(message, field)
127
- for field in message.__dataclass_fields__
128
- }
129
- else:
130
- # Try to extract common attributes
131
- for attr in [
132
- "content",
133
- "subtype",
134
- "data",
135
- "session_id",
136
- "stop_reason",
137
- "usage",
138
- "total_cost_usd",
139
- ]:
140
- if hasattr(message, attr):
141
- message_dict[attr] = getattr(message, attr)
142
-
143
- converted_message = model_class.model_validate(message_dict)
144
-
145
- logger.debug(
146
- "claude_sdk_message_converted_successfully",
147
- original_type=type(message).__name__,
148
- converted_type=type(converted_message).__name__,
149
- message_count=message_count,
150
- request_id=request_id,
385
+ async def _query_with_session_pool(
386
+ self,
387
+ message: SDKMessage,
388
+ options: ClaudeCodeOptions,
389
+ request_id: str | None = None,
390
+ session_id: str | None = None,
391
+ ) -> AsyncIterator[
392
+ sdk_models.UserMessage
393
+ | sdk_models.AssistantMessage
394
+ | sdk_models.SystemMessage
395
+ | sdk_models.ResultMessage
396
+ ]:
397
+ """Execute query using session-aware pooled connection."""
398
+ async with timed_operation("claude_sdk_query_session_pool", request_id) as op:
399
+ try:
400
+ if not session_id:
401
+ raise ClaudeSDKError("Session ID required for session pool")
402
+
403
+ if not self._session_manager:
404
+ raise ClaudeSDKError("No session manager available")
405
+
406
+ # Enable continue conversation for session pool
407
+ # so conversation is possible to resume based on session_id
408
+ options.continue_conversation = True
409
+
410
+ session_client = await self._session_manager.get_session_client(
411
+ session_id, options
412
+ )
413
+
414
+ async with session_client.lock: # Prevent concurrent access
415
+ session_client.update_usage()
416
+
417
+ # Ensure client is connected
418
+ if not session_client.claude_client:
419
+ logger.error(
420
+ "session_client_not_connected",
421
+ session_id=session_id,
422
+ status=session_client.status,
151
423
  )
152
- yield converted_message
153
- except Exception as e:
154
- logger.warning(
155
- "claude_sdk_message_conversion_failed",
156
- message_type=type(message).__name__,
157
- model_class=model_class.__name__,
158
- error=str(e),
424
+ raise ClaudeSDKError(
425
+ f"Session client not connected for session {session_id}"
159
426
  )
160
- # Skip invalid messages rather than crashing
161
- continue
162
427
 
163
- # Store final metrics
164
- op["message_count"] = message_count
165
- self._last_api_call_time_ms = op.get("duration_ms", 0.0)
428
+ # Mark session as having active stream
429
+ session_client.has_active_stream = True
430
+
431
+ # Create wrapped stream generator
432
+ async def stream_with_cleanup() -> AsyncIterator[
433
+ sdk_models.UserMessage
434
+ | sdk_models.AssistantMessage
435
+ | sdk_models.SystemMessage
436
+ | sdk_models.ResultMessage
437
+ ]:
438
+ stream_iterator = None
439
+ try:
440
+ message_count = 0
441
+ if not session_client.claude_client:
442
+ raise ClaudeSDKError("Session client not connected")
443
+
444
+ stream_iterator = self._execute_with_client(
445
+ session_client.claude_client,
446
+ message,
447
+ session_id,
448
+ request_id,
449
+ session_client=session_client,
450
+ )
166
451
 
167
- logger.debug(
168
- "claude_sdk_query_completed",
169
- message_count=message_count,
170
- duration_ms=op.get("duration_ms"),
171
- )
452
+ async for msg in stream_iterator:
453
+ message_count += 1
454
+ yield msg
455
+
456
+ op["message_count"] = message_count
457
+ op["session_id"] = session_id
458
+ self._last_api_call_time_ms = op.get("duration_ms", 0.0)
459
+
460
+ except GeneratorExit:
461
+ # Client disconnected - mark session for drain
462
+ logger.warning(
463
+ "claude_sdk_session_stream_interrupted",
464
+ session_id=session_id,
465
+ request_id=request_id,
466
+ message="Client disconnected, session will drain stream on next interrupt",
467
+ )
172
468
 
173
- except (CLINotFoundError, CLIConnectionError) as e:
469
+ # Just mark that stream needs draining
470
+ # The SessionClient.interrupt() will handle the actual draining
471
+ session_client.has_active_stream = True
472
+ raise
473
+ finally:
474
+ # Clean up if stream completed normally
475
+ if not session_client.has_active_stream:
476
+ session_client.has_active_stream = False
477
+
478
+ # Yield from the wrapped generator
479
+ async for msg in stream_with_cleanup():
480
+ yield msg
481
+
482
+ except StreamTimeoutError:
483
+ raise # Let service layer handle
484
+ except Exception as e:
174
485
  logger.error(
175
- "claude_sdk_connection_failed",
486
+ "claude_sdk_session_pool_query_error",
176
487
  error=str(e),
177
488
  error_type=type(e).__name__,
489
+ session_id=session_id,
490
+ exc_info=True,
491
+ )
492
+ # Fall back to direct query
493
+ logger.info(
494
+ "claude_sdk_fallback_to_direct_query", session_id=session_id
178
495
  )
179
- raise ServiceUnavailableError(
180
- f"Claude CLI not available: {str(e)}"
181
- ) from e
182
- except (ProcessError, CLIJSONDecodeError) as e:
496
+ async for msg in self._query(message, options, request_id, session_id):
497
+ yield msg
498
+
499
+ async def _wait_for_first_chunk(
500
+ self,
501
+ message_iterator: AsyncIterator[Any],
502
+ timeout_seconds: float = 5.0,
503
+ session_id: str | None = None,
504
+ request_id: str | None = None,
505
+ ) -> tuple[Any, AsyncIterator[Any]]:
506
+ """
507
+ Wait for the first chunk from an async iterator with timeout.
508
+
509
+ Args:
510
+ message_iterator: The async iterator to get messages from
511
+ timeout_seconds: Timeout in seconds (default 5.0)
512
+ session_id: Optional session ID for logging
513
+ request_id: Optional request ID for logging
514
+
515
+ Returns:
516
+ Tuple of (first_message, remaining_iterator)
517
+
518
+ Raises:
519
+ StreamTimeoutError: If no chunk is received within timeout
520
+ """
521
+ try:
522
+ # Wait for the first chunk with timeout - don't care about message type
523
+ logger.debug("waiting_for_first_chunk", timeout=timeout_seconds)
524
+ first_message = await asyncio.wait_for(
525
+ anext(message_iterator), timeout=timeout_seconds
526
+ )
527
+ return first_message, message_iterator
528
+ except TimeoutError:
529
+ # Check if session pool is enabled - if so, let it handle the timeout
530
+ has_session_pool = (
531
+ self._session_manager and await self._session_manager.has_session_pool()
532
+ )
533
+
534
+ if has_session_pool:
183
535
  logger.error(
184
- "claude_sdk_process_failed",
185
- error=str(e),
186
- error_type=type(e).__name__,
536
+ "first_chunk_timeout",
537
+ session_id=session_id,
538
+ request_id=request_id,
539
+ timeout=timeout_seconds,
540
+ message="No chunk received within timeout, session pool will handle cleanup",
187
541
  )
188
- raise ClaudeProxyError(
189
- message=f"Claude process error: {str(e)}",
190
- error_type="service_unavailable_error",
191
- status_code=503,
192
- ) from e
193
- except Exception as e:
542
+ else:
194
543
  logger.error(
195
- "claude_sdk_unexpected_error_occurred",
196
- error=str(e),
197
- error_type=type(e).__name__,
544
+ "first_chunk_timeout",
545
+ session_id=session_id,
546
+ request_id=request_id,
547
+ timeout=timeout_seconds,
548
+ message="No chunk received within timeout, interrupting session",
198
549
  )
199
- raise ClaudeProxyError(
200
- message=f"Unexpected error: {str(e)}",
201
- error_type="internal_server_error",
202
- status_code=500,
203
- ) from e
550
+ # Interrupt the session if we have a session_id and session manager (no session pool)
551
+ if session_id and self._session_manager:
552
+ try:
553
+ await self._session_manager.interrupt_session(session_id)
554
+ except Exception as e:
555
+ logger.error(
556
+ "failed_to_interrupt_stuck_session",
557
+ session_id=session_id,
558
+ error=str(e),
559
+ )
204
560
 
205
- def get_last_api_call_time_ms(self) -> float:
561
+ # Raise a custom exception with error details
562
+ raise StreamTimeoutError(
563
+ message=f"Stream timeout: No response received within {timeout_seconds} seconds. The command may not be supported or the session may be stuck.",
564
+ session_id=session_id or "unknown",
565
+ timeout_seconds=timeout_seconds,
566
+ ) from None
567
+
568
+ async def _process_message_stream(
569
+ self,
570
+ message_iterator: AsyncIterator[Any],
571
+ request_id: str | None = None,
572
+ session_id: str | None = None,
573
+ session_client: Any = None, # SessionClient for session pool
574
+ drain_mode: bool = False, # If True, consume but don't yield
575
+ ) -> AsyncIterator[
576
+ sdk_models.UserMessage
577
+ | sdk_models.AssistantMessage
578
+ | sdk_models.SystemMessage
579
+ | sdk_models.ResultMessage
580
+ ]:
206
581
  """
207
- Get the duration of the last Claude API call in milliseconds.
582
+ Process messages from an async iterator, converting them to Pydantic models.
583
+
584
+ Args:
585
+ message_iterator: The async iterator of SDK messages
586
+ request_id: Optional request ID for logging
587
+ session_id: Optional session ID for logging
588
+ session_client: Optional session context for session pool operations
589
+ drain_mode: If True, consume messages without yielding (for cleanup)
590
+
591
+ Yields:
592
+ Converted Pydantic model messages (unless drain_mode is True)
593
+ """
594
+ async for sdk_msg in message_iterator:
595
+ # Find matching type and convert
596
+ for sdk_type, model_type in self.MESSAGE_TYPE_MAP.items():
597
+ if isinstance(sdk_msg, sdk_type):
598
+ try:
599
+ converted_message = cast(
600
+ sdk_models.UserMessage
601
+ | sdk_models.AssistantMessage
602
+ | sdk_models.SystemMessage
603
+ | sdk_models.ResultMessage,
604
+ self._convert_message(sdk_msg, model_type),
605
+ )
606
+
607
+ # Special handling for ResultMessage
608
+ if session_client and isinstance(
609
+ converted_message, sdk_models.ResultMessage
610
+ ):
611
+ session_client.sdk_session_id = converted_message.session_id
612
+
613
+ # Only yield if not in drain mode
614
+ if not drain_mode:
615
+ yield converted_message
616
+ else:
617
+ logger.debug(
618
+ "claude_sdk_draining_message",
619
+ message_type=type(converted_message).__name__,
620
+ request_id=request_id,
621
+ session_id=session_id,
622
+ )
623
+ except Exception as e:
624
+ logger.warning(
625
+ "claude_sdk_message_conversion_failed",
626
+ message_type=type(sdk_msg).__name__,
627
+ error=str(e),
628
+ request_id=request_id,
629
+ session_id=session_id,
630
+ )
631
+ break
632
+ else:
633
+ # No matching type found
634
+ logger.warning(
635
+ "claude_sdk_unknown_message_type",
636
+ message_type=type(sdk_msg).__name__,
637
+ request_id=request_id,
638
+ session_id=session_id,
639
+ )
640
+
641
+ async def _create_drain_task(
642
+ self,
643
+ message_iterator: AsyncIterator[Any],
644
+ session_client: Any,
645
+ request_id: str | None = None,
646
+ session_id: str | None = None,
647
+ ) -> asyncio.Task[None]:
648
+ """Create a background task to drain remaining messages from stream.
649
+
650
+ Args:
651
+ message_iterator: The message iterator to drain
652
+ session_client: Session client to update stream status
653
+ request_id: Optional request ID for logging
654
+ session_id: Optional session ID for logging
208
655
 
209
656
  Returns:
210
- Duration in milliseconds, or 0.0 if no call has been made yet
657
+ Task that completes when stream is drained
211
658
  """
212
- return self._last_api_call_time_ms
659
+
660
+ async def drain_stream() -> None:
661
+ try:
662
+ logger.info(
663
+ "claude_sdk_starting_stream_drain",
664
+ session_id=session_id,
665
+ request_id=request_id,
666
+ )
667
+
668
+ message_count = 0
669
+ async for _ in self._process_message_stream(
670
+ message_iterator,
671
+ request_id=request_id,
672
+ session_id=session_id,
673
+ session_client=session_client,
674
+ drain_mode=True,
675
+ ):
676
+ message_count += 1
677
+
678
+ logger.info(
679
+ "claude_sdk_stream_drained",
680
+ session_id=session_id,
681
+ request_id=request_id,
682
+ drained_messages=message_count,
683
+ )
684
+ except Exception as e:
685
+ logger.error(
686
+ "claude_sdk_stream_drain_error",
687
+ session_id=session_id,
688
+ request_id=request_id,
689
+ error=str(e),
690
+ error_type=type(e).__name__,
691
+ )
692
+ finally:
693
+ if session_client:
694
+ session_client.has_active_stream = False
695
+ session_client.active_stream_task = None
696
+
697
+ return asyncio.create_task(drain_stream())
698
+
699
+ def _convert_message(self, message: Any, model_class: type[T]) -> T:
700
+ """Convert SDK message to Pydantic model."""
701
+ # Try standard object attribute extraction first
702
+ if hasattr(message, "__dict__"):
703
+ return model_class.model_validate(vars(message))
704
+
705
+ # Handle dataclass objects
706
+ if hasattr(message, "__dataclass_fields__"):
707
+ message_dict = {
708
+ field: getattr(message, field) for field in message.__dataclass_fields__
709
+ }
710
+ return model_class.model_validate(message_dict)
711
+
712
+ # Fallback: extract common attributes
713
+ message_dict = {}
714
+ for attr in [
715
+ "content",
716
+ "subtype",
717
+ "data",
718
+ "session_id",
719
+ "stop_reason",
720
+ "usage",
721
+ "total_cost_usd",
722
+ ]:
723
+ if hasattr(message, attr):
724
+ message_dict[attr] = getattr(message, attr)
725
+
726
+ return model_class.model_validate(message_dict)
213
727
 
214
728
  async def validate_health(self) -> bool:
215
729
  """
@@ -238,15 +752,31 @@ class ClaudeSDKClient:
238
752
  )
239
753
  return False
240
754
 
755
+ async def interrupt_session(self, session_id: str) -> bool:
756
+ """Interrupt a specific session due to client disconnection.
757
+
758
+ Args:
759
+ session_id: The session ID to interrupt
760
+
761
+ Returns:
762
+ True if session was found and interrupted, False otherwise
763
+ """
764
+ logger.debug("sdk_client_interrupt_session_started", session_id=session_id)
765
+ if self._session_manager:
766
+ logger.info(
767
+ "client_interrupt_session_requested",
768
+ session_id=session_id,
769
+ has_session_manager=True,
770
+ )
771
+ return await self._session_manager.interrupt_session(session_id)
772
+ else:
773
+ logger.warning(
774
+ "client_interrupt_session_no_session_manager",
775
+ session_id=session_id,
776
+ )
777
+ return False
778
+
241
779
  async def close(self) -> None:
242
780
  """Close the client and cleanup resources."""
243
781
  # Claude Code SDK doesn't require explicit cleanup
244
782
  pass
245
-
246
- async def __aenter__(self) -> "ClaudeSDKClient":
247
- """Async context manager entry."""
248
- return self
249
-
250
- async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
251
- """Async context manager exit."""
252
- await self.close()