ccproxy-api 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. ccproxy/_version.py +2 -2
  2. ccproxy/adapters/openai/adapter.py +1 -1
  3. ccproxy/adapters/openai/streaming.py +1 -0
  4. ccproxy/api/app.py +134 -224
  5. ccproxy/api/dependencies.py +22 -2
  6. ccproxy/api/middleware/errors.py +27 -3
  7. ccproxy/api/middleware/logging.py +4 -0
  8. ccproxy/api/responses.py +6 -1
  9. ccproxy/api/routes/claude.py +222 -17
  10. ccproxy/api/routes/proxy.py +25 -6
  11. ccproxy/api/services/permission_service.py +2 -2
  12. ccproxy/claude_sdk/__init__.py +4 -8
  13. ccproxy/claude_sdk/client.py +661 -131
  14. ccproxy/claude_sdk/exceptions.py +16 -0
  15. ccproxy/claude_sdk/manager.py +219 -0
  16. ccproxy/claude_sdk/message_queue.py +342 -0
  17. ccproxy/claude_sdk/options.py +5 -0
  18. ccproxy/claude_sdk/session_client.py +546 -0
  19. ccproxy/claude_sdk/session_pool.py +550 -0
  20. ccproxy/claude_sdk/stream_handle.py +538 -0
  21. ccproxy/claude_sdk/stream_worker.py +392 -0
  22. ccproxy/claude_sdk/streaming.py +53 -11
  23. ccproxy/cli/commands/serve.py +96 -0
  24. ccproxy/cli/options/claude_options.py +47 -0
  25. ccproxy/config/__init__.py +0 -3
  26. ccproxy/config/claude.py +171 -23
  27. ccproxy/config/discovery.py +10 -1
  28. ccproxy/config/scheduler.py +4 -4
  29. ccproxy/config/settings.py +19 -1
  30. ccproxy/core/http_transformers.py +305 -73
  31. ccproxy/core/logging.py +108 -12
  32. ccproxy/core/transformers.py +5 -0
  33. ccproxy/models/claude_sdk.py +57 -0
  34. ccproxy/models/detection.py +126 -0
  35. ccproxy/observability/access_logger.py +72 -14
  36. ccproxy/observability/metrics.py +151 -0
  37. ccproxy/observability/storage/duckdb_simple.py +12 -0
  38. ccproxy/observability/storage/models.py +16 -0
  39. ccproxy/observability/streaming_response.py +107 -0
  40. ccproxy/scheduler/manager.py +31 -6
  41. ccproxy/scheduler/tasks.py +122 -0
  42. ccproxy/services/claude_detection_service.py +269 -0
  43. ccproxy/services/claude_sdk_service.py +333 -130
  44. ccproxy/services/proxy_service.py +91 -200
  45. ccproxy/utils/__init__.py +9 -1
  46. ccproxy/utils/disconnection_monitor.py +83 -0
  47. ccproxy/utils/id_generator.py +12 -0
  48. ccproxy/utils/startup_helpers.py +408 -0
  49. {ccproxy_api-0.1.4.dist-info → ccproxy_api-0.1.5.dist-info}/METADATA +29 -2
  50. {ccproxy_api-0.1.4.dist-info → ccproxy_api-0.1.5.dist-info}/RECORD +53 -41
  51. ccproxy/config/loader.py +0 -105
  52. {ccproxy_api-0.1.4.dist-info → ccproxy_api-0.1.5.dist-info}/WHEEL +0 -0
  53. {ccproxy_api-0.1.4.dist-info → ccproxy_api-0.1.5.dist-info}/entry_points.txt +0 -0
  54. {ccproxy_api-0.1.4.dist-info → ccproxy_api-0.1.5.dist-info}/licenses/LICENSE +0 -0
@@ -338,6 +338,59 @@ SDKContentBlock = Annotated[
338
338
  ExtendedContentBlock = SDKContentBlock
339
339
 
340
340
 
341
+ # SDK Query Message Types
342
class SDKMessageContent(BaseModel):
    """Content structure for SDK query messages."""

    # Reject unknown keys so malformed payloads fail fast.
    model_config = ConfigDict(extra="forbid")

    # Only user-role content is ever sent through the SDK query path.
    role: Literal["user"] = "user"
    content: str = Field(description="Message text content")
349
+
350
+
351
class SDKMessage(BaseModel):
    """Message format used to send queries over the Claude SDK.

    This represents the internal message structure expected by the
    Claude Code SDK client for query operations.
    """

    # Unknown keys are rejected to keep the wire format strict.
    model_config = ConfigDict(extra="forbid")

    # Envelope type; the SDK only accepts user-originated queries here.
    type: Literal["user"] = "user"
    message: SDKMessageContent = Field(
        description="Message content with role and text"
    )
    parent_tool_use_id: str | None = Field(
        default=None, description="Optional parent tool use ID"
    )
    session_id: str | None = Field(
        default=None, description="Optional session ID for conversation continuity"
    )
370
+
371
+
372
def create_sdk_message(
    content: str,
    session_id: str | None = None,
    parent_tool_use_id: str | None = None,
) -> SDKMessage:
    """Create an SDKMessage instance for sending queries to Claude SDK.

    Args:
        content: The text content to send to Claude
        session_id: Optional session ID for conversation continuity
        parent_tool_use_id: Optional parent tool use ID

    Returns:
        SDKMessage instance ready to send to Claude SDK
    """
    # Wrap the raw text in the user-role content envelope first, then
    # build the outer message with the optional routing identifiers.
    body = SDKMessageContent(content=content)
    return SDKMessage(
        message=body,
        parent_tool_use_id=parent_tool_use_id,
        session_id=session_id,
    )
392
+
393
+
341
394
  # Conversion Functions
342
395
  def convert_sdk_text_block(text_content: str) -> TextBlock:
343
396
  """Convert raw text content to TextBlock model."""
@@ -404,6 +457,10 @@ __all__ = [
404
457
  "AssistantMessage",
405
458
  "SystemMessage",
406
459
  "ResultMessage",
460
+ # SDK Query Messages
461
+ "SDKMessageContent",
462
+ "SDKMessage",
463
+ "create_sdk_message",
407
464
  # Custom content blocks
408
465
  "SDKMessageMode",
409
466
  "ToolUseSDKBlock",
@@ -0,0 +1,126 @@
1
+ """Detection models for Claude Code CLI headers and system prompt extraction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from datetime import UTC, datetime
6
+ from typing import Annotated, Any
7
+
8
+ from pydantic import BaseModel, ConfigDict, Field
9
+
10
+
11
class ClaudeCodeHeaders(BaseModel):
    """Pydantic model for Claude CLI headers extraction with field aliases."""

    # Extra incoming headers are ignored; aliases let the model populate
    # either from header names or from python attribute names.
    model_config = ConfigDict(extra="ignore", populate_by_name=True)

    anthropic_beta: str = Field(
        default="claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14",
        alias="anthropic-beta",
        description="Anthropic beta features",
    )
    anthropic_version: str = Field(
        default="2023-06-01",
        alias="anthropic-version",
        description="Anthropic API version",
    )
    anthropic_dangerous_direct_browser_access: str = Field(
        default="true",
        alias="anthropic-dangerous-direct-browser-access",
        description="Browser access flag",
    )
    x_app: str = Field(
        default="cli", alias="x-app", description="Application identifier"
    )
    user_agent: str = Field(
        default="claude-cli/1.0.60 (external, cli)",
        alias="user-agent",
        description="User agent string",
    )
    x_stainless_lang: str = Field(
        default="js", alias="x-stainless-lang", description="SDK language"
    )
    x_stainless_retry_count: str = Field(
        default="0", alias="x-stainless-retry-count", description="Retry count"
    )
    x_stainless_timeout: str = Field(
        default="60", alias="x-stainless-timeout", description="Request timeout"
    )
    x_stainless_package_version: str = Field(
        default="0.55.1",
        alias="x-stainless-package-version",
        description="Package version",
    )
    x_stainless_os: str = Field(
        default="Linux", alias="x-stainless-os", description="Operating system"
    )
    x_stainless_arch: str = Field(
        default="x64", alias="x-stainless-arch", description="Architecture"
    )
    x_stainless_runtime: str = Field(
        default="node", alias="x-stainless-runtime", description="Runtime"
    )
    x_stainless_runtime_version: str = Field(
        default="v24.3.0",
        alias="x-stainless-runtime-version",
        description="Runtime version",
    )

    def to_headers_dict(self) -> dict[str, str]:
        """Convert to headers dictionary for HTTP forwarding with proper case."""
        # Field-name -> canonical HTTP header-name pairs. The casing here is
        # deliberate and differs from the lowercase parsing aliases above
        # (e.g. "User-Agent", "X-Stainless-*"), so it cannot simply be
        # derived from the field aliases.
        canonical_names = (
            ("anthropic_beta", "anthropic-beta"),
            ("anthropic_version", "anthropic-version"),
            (
                "anthropic_dangerous_direct_browser_access",
                "anthropic-dangerous-direct-browser-access",
            ),
            ("x_app", "x-app"),
            ("user_agent", "User-Agent"),
            ("x_stainless_lang", "X-Stainless-Lang"),
            ("x_stainless_retry_count", "X-Stainless-Retry-Count"),
            ("x_stainless_timeout", "X-Stainless-Timeout"),
            ("x_stainless_package_version", "X-Stainless-Package-Version"),
            ("x_stainless_os", "X-Stainless-OS"),
            ("x_stainless_arch", "X-Stainless-Arch"),
            ("x_stainless_runtime", "X-Stainless-Runtime"),
            ("x_stainless_runtime_version", "X-Stainless-Runtime-Version"),
        )
        # Skip any attribute that is somehow unset/None, mirroring the
        # defensive getattr(..., None) check of a plain loop.
        return {
            header: value
            for attr, header in canonical_names
            if (value := getattr(self, attr, None)) is not None
        }
95
+
96
+
97
class SystemPromptData(BaseModel):
    """Extracted system prompt information."""

    # Strict schema: any unexpected key is an error.
    model_config = ConfigDict(extra="forbid")

    # The system field is kept exactly as Claude CLI sent it -- either a bare
    # string or the structured block list -- so it can be replayed verbatim.
    system_field: Annotated[
        str | list[dict[str, Any]],
        Field(
            description="Complete system field as detected from Claude CLI, preserving exact structure including type, text, and cache_control"
        ),
    ]
108
+
109
+
110
class ClaudeCacheData(BaseModel):
    """Cached Claude CLI detection data with version tracking."""

    claude_version: Annotated[str, Field(description="Claude CLI version")]
    headers: Annotated[ClaudeCodeHeaders, Field(description="Extracted headers")]
    system_prompt: Annotated[
        SystemPromptData, Field(description="Extracted system prompt")
    ]
    # When this cache entry was created. The Field's default_factory supplies
    # a timezone-aware datetime.now(UTC) automatically, so the field is
    # optional at construction time without an explicit `= None` assignment.
    # (The previous `= None  # type: ignore` both required a type-ignore and
    # risked the literal None shadowing the factory as the field default.)
    cached_at: Annotated[
        datetime,
        Field(
            description="Cache timestamp",
            default_factory=lambda: datetime.now(UTC),
        ),
    ]

    model_config = ConfigDict(extra="forbid")
@@ -63,20 +63,31 @@ async def log_request_access(
63
63
  path = path or ctx_metadata.get("path")
64
64
  status_code = status_code or ctx_metadata.get("status_code")
65
65
 
66
- # Prepare comprehensive log data
66
+ # Prepare basic log data (always included)
67
67
  log_data = {
68
68
  "request_id": context.request_id,
69
69
  "method": method,
70
70
  "path": path,
71
71
  "query": query,
72
- "status_code": status_code,
73
72
  "client_ip": client_ip,
74
73
  "user_agent": user_agent,
75
- "duration_ms": context.duration_ms,
76
- "duration_seconds": context.duration_seconds,
77
- "error_message": error_message,
78
74
  }
79
75
 
76
+ # Add response-specific fields (only for completed requests)
77
+ is_streaming = ctx_metadata.get("streaming", False)
78
+ is_streaming_complete = ctx_metadata.get("event_type", "") == "streaming_complete"
79
+
80
+ # Include response fields only if this is not a streaming start
81
+ if not is_streaming or is_streaming_complete or ctx_metadata.get("error"):
82
+ log_data.update(
83
+ {
84
+ "status_code": status_code,
85
+ "duration_ms": context.duration_ms,
86
+ "duration_seconds": context.duration_seconds,
87
+ "error_message": error_message,
88
+ }
89
+ )
90
+
80
91
  # Add token and cost metrics if available
81
92
  token_fields = [
82
93
  "tokens_input",
@@ -85,6 +96,7 @@ async def log_request_access(
85
96
  "cache_write_tokens",
86
97
  "cost_usd",
87
98
  "cost_sdk_usd",
99
+ "num_turns",
88
100
  ]
89
101
 
90
102
  for field in token_fields:
@@ -93,18 +105,50 @@ async def log_request_access(
93
105
  log_data[field] = value
94
106
 
95
107
  # Add service and endpoint info
96
- service_fields = [
97
- "endpoint",
98
- "model",
99
- "streaming",
100
- "service_type",
101
- ]
108
+ service_fields = ["endpoint", "model", "streaming", "service_type", "headers"]
102
109
 
103
110
  for field in service_fields:
104
111
  value = ctx_metadata.get(field)
105
112
  if value is not None:
106
113
  log_data[field] = value
107
114
 
115
+ # Add session context metadata if available
116
+ session_fields = [
117
+ "session_id",
118
+ "session_type", # "session_pool" or "direct"
119
+ "session_status", # active, idle, connecting, etc.
120
+ "session_age_seconds", # how long session has been alive
121
+ "session_message_count", # number of messages in session
122
+ "session_pool_enabled", # whether session pooling is enabled
123
+ "session_idle_seconds", # how long since last activity
124
+ "session_error_count", # number of errors in this session
125
+ "session_is_new", # whether this is a newly created session
126
+ ]
127
+
128
+ for field in session_fields:
129
+ value = ctx_metadata.get(field)
130
+ if value is not None:
131
+ log_data[field] = value
132
+
133
+ # Add rate limit headers if available
134
+ rate_limit_fields = [
135
+ "x-ratelimit-limit",
136
+ "x-ratelimit-remaining",
137
+ "x-ratelimit-reset",
138
+ "anthropic-ratelimit-requests-limit",
139
+ "anthropic-ratelimit-requests-remaining",
140
+ "anthropic-ratelimit-requests-reset",
141
+ "anthropic-ratelimit-tokens-limit",
142
+ "anthropic-ratelimit-tokens-remaining",
143
+ "anthropic-ratelimit-tokens-reset",
144
+ "anthropic_request_id",
145
+ ]
146
+
147
+ for field in rate_limit_fields:
148
+ value = ctx_metadata.get(field)
149
+ if value is not None:
150
+ log_data[field] = value
151
+
108
152
  # Add any additional metadata provided
109
153
  log_data.update(additional_metadata)
110
154
 
@@ -112,15 +156,18 @@ async def log_request_access(
112
156
  log_data = {k: v for k, v in log_data.items() if v is not None}
113
157
 
114
158
  logger = context.logger.bind(**log_data)
115
- if not log_data.get("streaming", False):
159
+
160
+ if context.metadata.get("error"):
161
+ logger.warn("access_log", exc_info=context.metadata.get("error"))
162
+ elif not is_streaming:
116
163
  # Log as access_log event (structured logging)
117
164
  logger.info("access_log")
118
- elif log_data.get("event_type", "") == "streaming_complete":
165
+ elif is_streaming_complete:
119
166
  logger.info("access_log")
120
167
  else:
121
168
  # if streaming is true, and not streaming_complete log as debug
122
169
  # real access_log will come later
123
- logger.debug("access_log")
170
+ logger.info("access_log_streaming_start")
124
171
 
125
172
  # Store in DuckDB if available
126
173
  await _store_access_log(log_data, storage)
@@ -258,6 +305,17 @@ async def _store_access_log(
258
305
  "cache_write_tokens": log_data.get("cache_write_tokens", 0),
259
306
  "cost_usd": log_data.get("cost_usd", 0.0),
260
307
  "cost_sdk_usd": log_data.get("cost_sdk_usd", 0.0),
308
+ "num_turns": log_data.get("num_turns", 0),
309
+ # Session context metadata
310
+ "session_type": log_data.get("session_type", ""),
311
+ "session_status": log_data.get("session_status", ""),
312
+ "session_age_seconds": log_data.get("session_age_seconds", 0.0),
313
+ "session_message_count": log_data.get("session_message_count", 0),
314
+ "session_client_id": log_data.get("session_client_id", ""),
315
+ "session_pool_enabled": log_data.get("session_pool_enabled", False),
316
+ "session_idle_seconds": log_data.get("session_idle_seconds", 0.0),
317
+ "session_error_count": log_data.get("session_error_count", 0),
318
+ "session_is_new": log_data.get("session_is_new", True),
261
319
  }
262
320
 
263
321
  # Store asynchronously using queue-based DuckDB (prevents deadlocks)
@@ -205,6 +205,62 @@ class PrometheusMetrics:
205
205
  registry=self.registry,
206
206
  )
207
207
 
208
+ # Claude SDK Pool metrics
209
+ self.pool_clients_total = Gauge(
210
+ f"{self.namespace}_pool_clients_total",
211
+ "Total number of clients in the pool",
212
+ registry=self.registry,
213
+ )
214
+
215
+ self.pool_clients_available = Gauge(
216
+ f"{self.namespace}_pool_clients_available",
217
+ "Number of available clients in the pool",
218
+ registry=self.registry,
219
+ )
220
+
221
+ self.pool_clients_active = Gauge(
222
+ f"{self.namespace}_pool_clients_active",
223
+ "Number of active clients currently processing requests",
224
+ registry=self.registry,
225
+ )
226
+
227
+ self.pool_connections_created_total = Counter(
228
+ f"{self.namespace}_pool_connections_created_total",
229
+ "Total number of pool connections created",
230
+ registry=self.registry,
231
+ )
232
+
233
+ self.pool_connections_closed_total = Counter(
234
+ f"{self.namespace}_pool_connections_closed_total",
235
+ "Total number of pool connections closed",
236
+ registry=self.registry,
237
+ )
238
+
239
+ self.pool_acquisitions_total = Counter(
240
+ f"{self.namespace}_pool_acquisitions_total",
241
+ "Total number of client acquisitions from pool",
242
+ registry=self.registry,
243
+ )
244
+
245
+ self.pool_releases_total = Counter(
246
+ f"{self.namespace}_pool_releases_total",
247
+ "Total number of client releases to pool",
248
+ registry=self.registry,
249
+ )
250
+
251
+ self.pool_health_check_failures_total = Counter(
252
+ f"{self.namespace}_pool_health_check_failures_total",
253
+ "Total number of pool health check failures",
254
+ registry=self.registry,
255
+ )
256
+
257
+ self.pool_acquisition_duration = Histogram(
258
+ f"{self.namespace}_pool_acquisition_duration_seconds",
259
+ "Time taken to acquire a client from the pool",
260
+ buckets=[0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0],
261
+ registry=self.registry,
262
+ )
263
+
208
264
  # Set initial system info
209
265
  try:
210
266
  from ccproxy import __version__
@@ -468,6 +524,101 @@ class PrometheusMetrics:
468
524
  and self._pushgateway_client.is_enabled()
469
525
  )
470
526
 
527
+ # Claude SDK Pool metrics methods
528
+
529
+ def update_pool_gauges(
530
+ self,
531
+ total_clients: int,
532
+ available_clients: int,
533
+ active_clients: int,
534
+ ) -> None:
535
+ """
536
+ Update pool gauge metrics (current state).
537
+
538
+ Args:
539
+ total_clients: Total number of clients in pool
540
+ available_clients: Number of available clients
541
+ active_clients: Number of active clients
542
+ """
543
+ if not self._enabled:
544
+ return
545
+
546
+ # Update gauges
547
+ self.pool_clients_total.set(total_clients)
548
+ self.pool_clients_available.set(available_clients)
549
+ self.pool_clients_active.set(active_clients)
550
+
551
+ # Note: Counters are managed directly by the pool operations
552
+ # This method only updates the current gauges
553
+
554
+ def record_pool_acquisition_time(self, duration_seconds: float) -> None:
555
+ """
556
+ Record the time taken to acquire a client from the pool.
557
+
558
+ Args:
559
+ duration_seconds: Time in seconds to acquire client
560
+ """
561
+ if not self._enabled:
562
+ return
563
+
564
+ self.pool_acquisition_duration.observe(duration_seconds)
565
+
566
+ def inc_pool_connections_created(self) -> None:
567
+ """Increment the pool connections created counter."""
568
+ if not self._enabled:
569
+ return
570
+
571
+ self.pool_connections_created_total.inc()
572
+
573
+ def inc_pool_connections_closed(self) -> None:
574
+ """Increment the pool connections closed counter."""
575
+ if not self._enabled:
576
+ return
577
+
578
+ self.pool_connections_closed_total.inc()
579
+
580
+ def inc_pool_acquisitions(self) -> None:
581
+ """Increment the pool acquisitions counter."""
582
+ if not self._enabled:
583
+ return
584
+
585
+ self.pool_acquisitions_total.inc()
586
+
587
+ def inc_pool_releases(self) -> None:
588
+ """Increment the pool releases counter."""
589
+ if not self._enabled:
590
+ return
591
+
592
+ self.pool_releases_total.inc()
593
+
594
+ def inc_pool_health_check_failures(self) -> None:
595
+ """Increment the pool health check failures counter."""
596
+ if not self._enabled:
597
+ return
598
+
599
+ self.pool_health_check_failures_total.inc()
600
+
601
+ def set_pool_clients_total(self, count: int) -> None:
602
+ """Set the total number of clients in the pool."""
603
+ if not self._enabled:
604
+ return
605
+
606
+ self.pool_clients_total.set(count)
607
+
608
+ def set_pool_clients_available(self, count: int) -> None:
609
+ """Set the number of available clients in the pool."""
610
+ if not self._enabled:
611
+ return
612
+
613
+ self.pool_clients_available.set(count)
614
+
615
+ def set_pool_clients_active(self, count: int) -> None:
616
+ """Set the number of active clients in the pool."""
617
+ if not self._enabled:
618
+ return
619
+
620
+ self.pool_clients_active.set(count)
621
+
471
622
 
472
623
  # Global metrics instance
473
624
  _global_metrics: PrometheusMetrics | None = None
@@ -60,6 +60,18 @@ class AccessLogPayload(TypedDict, total=False):
60
60
  cache_write_tokens: int
61
61
  cost_usd: float
62
62
  cost_sdk_usd: float
63
+ num_turns: int # number of conversation turns
64
+
65
+ # Session context metadata
66
+ session_type: str # "session_pool" or "direct"
67
+ session_status: str # active, idle, connecting, etc.
68
+ session_age_seconds: float # how long session has been alive
69
+ session_message_count: int # number of messages in session
70
+ session_client_id: str # unique session client identifier
71
+ session_pool_enabled: bool # whether session pooling is enabled
72
+ session_idle_seconds: float # how long since last activity
73
+ session_error_count: int # number of errors in this session
74
+ session_is_new: bool # whether this is a newly created session
63
75
 
64
76
 
65
77
  class SimpleDuckDBStorage:
@@ -44,6 +44,22 @@ class AccessLog(SQLModel, table=True):
44
44
  cache_write_tokens: int = Field(default=0)
45
45
  cost_usd: float = Field(default=0.0)
46
46
  cost_sdk_usd: float = Field(default=0.0)
47
+ num_turns: int = Field(default=0) # number of conversation turns
48
+
49
+ # Session context metadata
50
+ session_type: str = Field(default="") # "session_pool" or "direct"
51
+ session_status: str = Field(default="") # active, idle, connecting, etc.
52
+ session_age_seconds: float = Field(default=0.0) # how long session has been alive
53
+ session_message_count: int = Field(default=0) # number of messages in session
54
+ session_client_id: str = Field(default="") # unique session client identifier
55
+ session_pool_enabled: bool = Field(
56
+ default=False
57
+ ) # whether session pooling is enabled
58
+ session_idle_seconds: float = Field(default=0.0) # how long since last activity
59
+ session_error_count: int = Field(default=0) # number of errors in this session
60
+ session_is_new: bool = Field(
61
+ default=True
62
+ ) # whether this is a newly created session
47
63
 
48
64
  class Config:
49
65
  """SQLModel configuration."""
@@ -0,0 +1,107 @@
1
+ """FastAPI StreamingResponse with automatic access logging on completion.
2
+
3
+ This module provides a reusable StreamingResponseWithLogging class that wraps
4
+ any async generator and handles access logging when the stream completes,
5
+ eliminating code duplication between different streaming endpoints.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from collections.abc import AsyncGenerator, AsyncIterator
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ import structlog
14
+ from fastapi.responses import StreamingResponse
15
+
16
+ from ccproxy.observability.access_logger import log_request_access
17
+
18
+
19
+ if TYPE_CHECKING:
20
+ from ccproxy.observability.context import RequestContext
21
+ from ccproxy.observability.metrics import PrometheusMetrics
22
+
23
+ logger = structlog.get_logger(__name__)
24
+
25
+
26
+ class StreamingResponseWithLogging(StreamingResponse):
27
+ """FastAPI StreamingResponse that triggers access logging on completion.
28
+
29
+ This class wraps a streaming response generator to automatically trigger
30
+ access logging when the stream completes (either successfully or with an error).
31
+ This eliminates the need for manual access logging in individual stream processors.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ content: AsyncGenerator[bytes, None] | AsyncIterator[bytes],
37
+ request_context: RequestContext,
38
+ metrics: PrometheusMetrics | None = None,
39
+ status_code: int = 200,
40
+ **kwargs: Any,
41
+ ) -> None:
42
+ """Initialize streaming response with logging capability.
43
+
44
+ Args:
45
+ content: The async generator producing streaming content
46
+ request_context: The request context for access logging
47
+ metrics: Optional PrometheusMetrics instance for recording metrics
48
+ status_code: HTTP status code for the response
49
+ **kwargs: Additional arguments passed to StreamingResponse
50
+ """
51
+ # Wrap the content generator to add logging
52
+ logged_content = self._wrap_with_logging(
53
+ content, request_context, metrics, status_code
54
+ )
55
+ super().__init__(logged_content, status_code=status_code, **kwargs)
56
+
57
+ async def _wrap_with_logging(
58
+ self,
59
+ content: AsyncGenerator[bytes, None] | AsyncIterator[bytes],
60
+ context: RequestContext,
61
+ metrics: PrometheusMetrics | None,
62
+ status_code: int,
63
+ ) -> AsyncGenerator[bytes, None]:
64
+ """Wrap content generator with access logging on completion.
65
+
66
+ Args:
67
+ content: The original content generator
68
+ context: Request context for logging
69
+ metrics: Optional metrics instance
70
+ status_code: HTTP status code
71
+
72
+ Yields:
73
+ bytes: Content chunks from the original generator
74
+ """
75
+ try:
76
+ # Stream all content from the original generator
77
+ async for chunk in content:
78
+ yield chunk
79
+ except GeneratorExit:
80
+ # Client disconnected - log this and re-raise to propagate to underlying generators
81
+ logger.info(
82
+ "streaming_response_client_disconnected",
83
+ request_id=context.request_id,
84
+ message="Client disconnected from streaming response, propagating GeneratorExit",
85
+ )
86
+ # CRITICAL: Re-raise GeneratorExit to propagate disconnect to create_listener()
87
+ raise
88
+ finally:
89
+ # Log access when stream completes (success or error)
90
+ try:
91
+ # Add streaming completion event type to context
92
+ context.add_metadata(event_type="streaming_complete")
93
+
94
+ # Check if status_code was updated in context metadata (e.g., due to error)
95
+ final_status_code = context.metadata.get("status_code", status_code)
96
+
97
+ await log_request_access(
98
+ context=context,
99
+ status_code=final_status_code,
100
+ metrics=metrics,
101
+ )
102
+ except Exception as e:
103
+ logger.warning(
104
+ "streaming_access_log_failed",
105
+ error=str(e),
106
+ request_id=context.request_id,
107
+ )