headroom-ai 0.2.13: headroom_ai-0.2.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. headroom/__init__.py +212 -0
  2. headroom/cache/__init__.py +76 -0
  3. headroom/cache/anthropic.py +517 -0
  4. headroom/cache/base.py +342 -0
  5. headroom/cache/compression_feedback.py +613 -0
  6. headroom/cache/compression_store.py +814 -0
  7. headroom/cache/dynamic_detector.py +1026 -0
  8. headroom/cache/google.py +884 -0
  9. headroom/cache/openai.py +584 -0
  10. headroom/cache/registry.py +175 -0
  11. headroom/cache/semantic.py +451 -0
  12. headroom/ccr/__init__.py +77 -0
  13. headroom/ccr/context_tracker.py +582 -0
  14. headroom/ccr/mcp_server.py +319 -0
  15. headroom/ccr/response_handler.py +772 -0
  16. headroom/ccr/tool_injection.py +415 -0
  17. headroom/cli.py +219 -0
  18. headroom/client.py +977 -0
  19. headroom/compression/__init__.py +42 -0
  20. headroom/compression/detector.py +424 -0
  21. headroom/compression/handlers/__init__.py +22 -0
  22. headroom/compression/handlers/base.py +219 -0
  23. headroom/compression/handlers/code_handler.py +506 -0
  24. headroom/compression/handlers/json_handler.py +418 -0
  25. headroom/compression/masks.py +345 -0
  26. headroom/compression/universal.py +465 -0
  27. headroom/config.py +474 -0
  28. headroom/exceptions.py +192 -0
  29. headroom/integrations/__init__.py +159 -0
  30. headroom/integrations/agno/__init__.py +53 -0
  31. headroom/integrations/agno/hooks.py +345 -0
  32. headroom/integrations/agno/model.py +625 -0
  33. headroom/integrations/agno/providers.py +154 -0
  34. headroom/integrations/langchain/__init__.py +106 -0
  35. headroom/integrations/langchain/agents.py +326 -0
  36. headroom/integrations/langchain/chat_model.py +1002 -0
  37. headroom/integrations/langchain/langsmith.py +324 -0
  38. headroom/integrations/langchain/memory.py +319 -0
  39. headroom/integrations/langchain/providers.py +200 -0
  40. headroom/integrations/langchain/retriever.py +371 -0
  41. headroom/integrations/langchain/streaming.py +341 -0
  42. headroom/integrations/mcp/__init__.py +37 -0
  43. headroom/integrations/mcp/server.py +533 -0
  44. headroom/memory/__init__.py +37 -0
  45. headroom/memory/extractor.py +390 -0
  46. headroom/memory/fast_store.py +621 -0
  47. headroom/memory/fast_wrapper.py +311 -0
  48. headroom/memory/inline_extractor.py +229 -0
  49. headroom/memory/store.py +434 -0
  50. headroom/memory/worker.py +260 -0
  51. headroom/memory/wrapper.py +321 -0
  52. headroom/models/__init__.py +39 -0
  53. headroom/models/registry.py +687 -0
  54. headroom/parser.py +293 -0
  55. headroom/pricing/__init__.py +51 -0
  56. headroom/pricing/anthropic_prices.py +81 -0
  57. headroom/pricing/litellm_pricing.py +113 -0
  58. headroom/pricing/openai_prices.py +91 -0
  59. headroom/pricing/registry.py +188 -0
  60. headroom/providers/__init__.py +61 -0
  61. headroom/providers/anthropic.py +621 -0
  62. headroom/providers/base.py +131 -0
  63. headroom/providers/cohere.py +362 -0
  64. headroom/providers/google.py +427 -0
  65. headroom/providers/litellm.py +297 -0
  66. headroom/providers/openai.py +566 -0
  67. headroom/providers/openai_compatible.py +521 -0
  68. headroom/proxy/__init__.py +19 -0
  69. headroom/proxy/server.py +2683 -0
  70. headroom/py.typed +0 -0
  71. headroom/relevance/__init__.py +124 -0
  72. headroom/relevance/base.py +106 -0
  73. headroom/relevance/bm25.py +255 -0
  74. headroom/relevance/embedding.py +255 -0
  75. headroom/relevance/hybrid.py +259 -0
  76. headroom/reporting/__init__.py +5 -0
  77. headroom/reporting/generator.py +549 -0
  78. headroom/storage/__init__.py +41 -0
  79. headroom/storage/base.py +125 -0
  80. headroom/storage/jsonl.py +220 -0
  81. headroom/storage/sqlite.py +289 -0
  82. headroom/telemetry/__init__.py +91 -0
  83. headroom/telemetry/collector.py +764 -0
  84. headroom/telemetry/models.py +880 -0
  85. headroom/telemetry/toin.py +1579 -0
  86. headroom/tokenizer.py +80 -0
  87. headroom/tokenizers/__init__.py +75 -0
  88. headroom/tokenizers/base.py +210 -0
  89. headroom/tokenizers/estimator.py +198 -0
  90. headroom/tokenizers/huggingface.py +317 -0
  91. headroom/tokenizers/mistral.py +245 -0
  92. headroom/tokenizers/registry.py +398 -0
  93. headroom/tokenizers/tiktoken_counter.py +248 -0
  94. headroom/transforms/__init__.py +106 -0
  95. headroom/transforms/base.py +57 -0
  96. headroom/transforms/cache_aligner.py +357 -0
  97. headroom/transforms/code_compressor.py +1313 -0
  98. headroom/transforms/content_detector.py +335 -0
  99. headroom/transforms/content_router.py +1158 -0
  100. headroom/transforms/llmlingua_compressor.py +638 -0
  101. headroom/transforms/log_compressor.py +529 -0
  102. headroom/transforms/pipeline.py +297 -0
  103. headroom/transforms/rolling_window.py +350 -0
  104. headroom/transforms/search_compressor.py +365 -0
  105. headroom/transforms/smart_crusher.py +2682 -0
  106. headroom/transforms/text_compressor.py +259 -0
  107. headroom/transforms/tool_crusher.py +338 -0
  108. headroom/utils.py +215 -0
  109. headroom_ai-0.2.13.dist-info/METADATA +315 -0
  110. headroom_ai-0.2.13.dist-info/RECORD +114 -0
  111. headroom_ai-0.2.13.dist-info/WHEEL +4 -0
  112. headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
  113. headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
  114. headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/ccr/response_handler.py
@@ -0,0 +1,772 @@
+"""Response handling for CCR (Compress-Cache-Retrieve).
+
+This module provides response interception and CCR tool call handling.
+When the LLM calls headroom_retrieve, this handler:
+1. Detects the tool call in the response
+2. Retrieves content from the compression store
+3. Continues the conversation with the tool result
+4. Returns the final response to the client
+
+This solves the critical gap where the proxy injects the tool but
+can't handle the LLM's tool calls.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass, field
+from typing import Any
+
+from ..cache.compression_store import get_compression_store
+from .tool_injection import CCR_TOOL_NAME, parse_tool_call
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class CCRToolCall:
+    """Represents a detected CCR tool call."""
+
+    tool_call_id: str
+    hash_key: str
+    query: str | None = None
+
+
+@dataclass
+class CCRToolResult:
+    """Result of handling a CCR tool call."""
+
+    tool_call_id: str
+    content: str
+    success: bool
+    items_retrieved: int = 0
+    was_search: bool = False
+
+
+@dataclass
+class ResponseHandlerConfig:
+    """Configuration for CCR response handling."""
+
+    # Whether to handle CCR tool calls automatically
+    enabled: bool = True
+
+    # Maximum number of CCR retrieval rounds (prevent infinite loops)
+    max_retrieval_rounds: int = 3
+
+    # Whether to strip CCR tool calls from final response
+    strip_ccr_from_response: bool = True
+
+    # Timeout for continuation requests (ms)
+    continuation_timeout_ms: int = 120000
+
+
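# Illustrative sketch, not part of the packaged file: overriding the config
# defaults above. Fields left unset keep their declared dataclass defaults.
custom = ResponseHandlerConfig(max_retrieval_rounds=2, continuation_timeout_ms=60_000)
assert custom.enabled and custom.strip_ccr_from_response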
+class CCRResponseHandler:
+    """Handles CCR tool calls in LLM responses.
+
+    This handler intercepts responses, detects CCR tool calls,
+    retrieves content, and continues the conversation until
+    the LLM produces a response without CCR tool calls.
+
+    Example flow:
+    1. LLM response contains: tool_use(headroom_retrieve, hash=abc123)
+    2. Handler detects this, retrieves original content
+    3. Handler makes another API call with tool result
+    4. LLM responds with actual content (no CCR tool call)
+    5. Handler returns this final response
+
+    Usage:
+        handler = CCRResponseHandler(config)
+
+        # Check if response needs handling
+        if handler.has_ccr_tool_calls(response_json):
+            # Handle the tool calls
+            final_response = await handler.handle_response(
+                response_json,
+                messages,
+                tools,
+                api_call_fn,
+                provider="anthropic"
+            )
+        else:
+            final_response = response_json
+    """
+
+    def __init__(self, config: ResponseHandlerConfig | None = None):
+        self.config = config or ResponseHandlerConfig()
+        self._retrieval_count = 0
+
+    def has_ccr_tool_calls(
+        self,
+        response: dict[str, Any],
+        provider: str = "anthropic",
+    ) -> bool:
+        """Check if response contains CCR tool calls.
+
+        Args:
+            response: The API response JSON.
+            provider: The provider type.
+
+        Returns:
+            True if response contains headroom_retrieve tool calls.
+        """
+        tool_calls = self._extract_tool_calls(response, provider)
+        return any(
+            tc.get("name") == CCR_TOOL_NAME or tc.get("function", {}).get("name") == CCR_TOOL_NAME
+            for tc in tool_calls
+        )
+
+    def _extract_tool_calls(
+        self,
+        response: dict[str, Any],
+        provider: str,
+    ) -> list[dict[str, Any]]:
+        """Extract tool calls from response based on provider format."""
+        if provider == "anthropic":
+            # Anthropic format: content blocks with type=tool_use
+            content = response.get("content", [])
+            if isinstance(content, list):
+                return [block for block in content if block.get("type") == "tool_use"]
+            return []
+
+        elif provider == "openai":
+            # OpenAI format: message.tool_calls array
+            # (guard against an empty choices list, not just a missing key)
+            choices = response.get("choices") or [{}]
+            message = choices[0].get("message", {})
+            tool_calls = message.get("tool_calls", [])
+            return list(tool_calls) if tool_calls else []
+
+        return []
+
+    def _parse_ccr_tool_calls(
+        self,
+        response: dict[str, Any],
+        provider: str,
+    ) -> tuple[list[CCRToolCall], list[dict[str, Any]]]:
+        """Parse CCR tool calls from a response, separating them from other tool calls.
+
+        Returns:
+            Tuple of (ccr_tool_calls, other_tool_calls)
+        """
+        all_tool_calls = self._extract_tool_calls(response, provider)
+
+        ccr_calls = []
+        other_calls = []
+
+        for tc in all_tool_calls:
+            hash_key, query = parse_tool_call(tc, provider)
+
+            if hash_key is not None:
+                # This is a CCR tool call
+                tool_call_id = tc.get("id", "")
+                ccr_calls.append(
+                    CCRToolCall(
+                        tool_call_id=tool_call_id,
+                        hash_key=hash_key,
+                        query=query,
+                    )
+                )
+            else:
+                # Not a CCR tool call
+                other_calls.append(tc)
+
+        return ccr_calls, other_calls
+
+    def _execute_retrieval(self, ccr_call: CCRToolCall) -> CCRToolResult:
+        """Execute a CCR retrieval.
+
+        Args:
+            ccr_call: The CCR tool call to execute.
+
+        Returns:
+            CCRToolResult with the retrieved content.
+        """
+        store = get_compression_store()
+
+        try:
+            if ccr_call.query:
+                # Search within compressed content
+                results = store.search(ccr_call.hash_key, ccr_call.query)
+                content = json.dumps(
+                    {
+                        "hash": ccr_call.hash_key,
+                        "query": ccr_call.query,
+                        "results": results,
+                        "count": len(results),
+                    },
+                    indent=2,
+                )
+                return CCRToolResult(
+                    tool_call_id=ccr_call.tool_call_id,
+                    content=content,
+                    success=True,
+                    items_retrieved=len(results),
+                    was_search=True,
+                )
+            else:
+                # Full retrieval
+                entry = store.retrieve(ccr_call.hash_key)
+                if entry:
+                    content = json.dumps(
+                        {
+                            "hash": ccr_call.hash_key,
+                            "original_content": entry.original_content,
+                            "original_item_count": entry.original_item_count,
+                        },
+                        indent=2,
+                    )
+                    return CCRToolResult(
+                        tool_call_id=ccr_call.tool_call_id,
+                        content=content,
+                        success=True,
+                        items_retrieved=entry.original_item_count,
+                        was_search=False,
+                    )
+                else:
+                    content = json.dumps(
+                        {
+                            "error": "Entry not found or expired (TTL: 5 minutes)",
+                            "hash": ccr_call.hash_key,
+                        },
+                        indent=2,
+                    )
+                    return CCRToolResult(
+                        tool_call_id=ccr_call.tool_call_id,
+                        content=content,
+                        success=False,
+                    )
+
+        except Exception as e:
+            logger.error(f"CCR retrieval failed for {ccr_call.hash_key}: {e}")
+            content = json.dumps(
+                {
+                    "error": f"Retrieval failed: {str(e)}",
+                    "hash": ccr_call.hash_key,
+                },
+                indent=2,
+            )
+            return CCRToolResult(
+                tool_call_id=ccr_call.tool_call_id,
+                content=content,
+                success=False,
+            )
+
+    def _create_tool_result_message(
+        self,
+        results: list[CCRToolResult],
+        provider: str,
+    ) -> dict[str, Any]:
+        """Create a tool result message from CCR results.
+
+        Args:
+            results: List of CCR tool results.
+            provider: The provider type.
+
+        Returns:
+            Message dict in the appropriate format.
+        """
+        if provider == "anthropic":
+            # Anthropic: user message with tool_result content blocks
+            content_blocks = []
+            for result in results:
+                content_blocks.append(
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": result.tool_call_id,
+                        "content": result.content,
+                    }
+                )
+            return {
+                "role": "user",
+                "content": content_blocks,
+            }
+
+        elif provider == "openai":
+            # OpenAI expects one tool message per tool call, so the list is
+            # wrapped under a sentinel key for the caller to extend with
+            return {
+                "_openai_tool_results": [
+                    {
+                        "role": "tool",
+                        "tool_call_id": result.tool_call_id,
+                        "content": result.content,
+                    }
+                    for result in results
+                ]
+            }
+
+        else:
+            # Generic format
+            return {
+                "role": "tool",
+                "content": json.dumps(
+                    [{"tool_call_id": r.tool_call_id, "result": r.content} for r in results]
+                ),
+            }
+
+    def _extract_assistant_message(
+        self,
+        response: dict[str, Any],
+        provider: str,
+    ) -> dict[str, Any]:
+        """Extract the assistant message from an API response.
+
+        Args:
+            response: The API response.
+            provider: The provider type.
+
+        Returns:
+            The assistant message dict.
+        """
+        if provider == "anthropic":
+            return {
+                "role": "assistant",
+                "content": response.get("content", []),
+            }
+        elif provider == "openai":
+            choices = response.get("choices") or [{}]
+            message = choices[0].get("message", {})
+            return {
+                "role": "assistant",
+                "content": message.get("content"),
+                "tool_calls": message.get("tool_calls"),
+            }
+        else:
+            return {
+                "role": "assistant",
+                "content": response.get("content", ""),
+            }
+
+    async def handle_response(
+        self,
+        response: dict[str, Any],
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]] | None,
+        api_call_fn: Callable[
+            [list[dict[str, Any]], list[dict[str, Any]] | None], Awaitable[dict[str, Any]]
+        ],
+        provider: str = "anthropic",
+    ) -> dict[str, Any]:
+        """Handle CCR tool calls in a response.
+
+        This method:
+        1. Detects CCR tool calls
+        2. Executes retrievals
+        3. Continues conversation with tool results
+        4. Repeats until no CCR tool calls remain
+
+        Args:
+            response: The initial API response.
+            messages: The conversation messages.
+            tools: The tools list (should include CCR tool).
+            api_call_fn: Async function to make API calls.
+                Signature: (messages, tools) -> response
+            provider: The provider type.
+
+        Returns:
+            The final response (with no CCR tool calls).
+        """
+        if not self.config.enabled:
+            return response
+
+        current_response = response
+        current_messages = list(messages)  # Copy to avoid mutation
+        rounds = 0
+
+        while rounds < self.config.max_retrieval_rounds:
+            # Check for CCR tool calls
+            ccr_calls, other_calls = self._parse_ccr_tool_calls(current_response, provider)
+
+            if not ccr_calls:
+                # No CCR tool calls, we're done
+                break
+
+            rounds += 1
+            self._retrieval_count += len(ccr_calls)
+
+            logger.info(f"CCR: Handling {len(ccr_calls)} retrieval(s) in round {rounds}")
+
+            # Execute all CCR retrievals
+            results = [self._execute_retrieval(call) for call in ccr_calls]
+
+            # Log retrieval stats
+            total_items = sum(r.items_retrieved for r in results)
+            searches = sum(1 for r in results if r.was_search)
+            logger.debug(
+                f"CCR: Retrieved {total_items} items "
+                f"({searches} searches, {len(results) - searches} full)"
+            )
+
+            # Build continuation messages
+            # Add assistant message (the response that had tool calls)
+            assistant_msg = self._extract_assistant_message(current_response, provider)
+            current_messages.append(assistant_msg)
+
+            # Add tool results
+            tool_result_msg = self._create_tool_result_message(results, provider)
+
+            if provider == "openai" and "_openai_tool_results" in tool_result_msg:
+                # OpenAI uses multiple messages for tool results
+                current_messages.extend(tool_result_msg["_openai_tool_results"])
+            else:
+                current_messages.append(tool_result_msg)
+
+            # Make continuation API call
+            try:
+                current_response = await api_call_fn(current_messages, tools)
+            except Exception as e:
+                logger.error(f"CCR: Continuation API call failed: {e}")
+                # Return the response we had (with unhandled CCR calls)
+                # The client will see the tool_use and might handle it differently
+                break
+
+        if rounds >= self.config.max_retrieval_rounds:
+            logger.warning(
+                f"CCR: Hit max retrieval rounds ({self.config.max_retrieval_rounds}), "
+                "returning response with possible unhandled CCR calls"
+            )
+
+        return current_response
+
+    def get_stats(self) -> dict[str, Any]:
+        """Get handler statistics."""
+        return {
+            "total_retrievals": self._retrieval_count,
+            "config": {
+                "enabled": self.config.enabled,
+                "max_rounds": self.config.max_retrieval_rounds,
+            },
+        }
+
+
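# Illustrative only, not part of the packaged file: handle_response() accepts
# any async callable of shape (messages, tools) -> response_json. A minimal
# sketch against the Anthropic Messages API, assuming httpx is installed and
# ANTHROPIC_API_KEY is set; the model id is just an example.
import os

import httpx


async def anthropic_call(
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
) -> dict[str, Any]:
    async with httpx.AsyncClient(timeout=120) as client:
        resp = await client.post(
            "https://api.anthropic.com/v1/messages",
            headers={
                "x-api-key": os.environ["ANTHROPIC_API_KEY"],
                "anthropic-version": "2023-06-01",
            },
            json={
                "model": "claude-sonnet-4-20250514",  # example model id
                "max_tokens": 1024,
                "messages": messages,
                **({"tools": tools} if tools else {}),
            },
        )
        resp.raise_for_status()
        return resp.json()

# Then, given a first non-streaming response `first_response`:
#     handler = CCRResponseHandler()
#     final = await handler.handle_response(
#         first_response, messages, tools, anthropic_call, provider="anthropic"
#     )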
+@dataclass
+class StreamingCCRBuffer:
+    """Buffer for detecting CCR tool calls in streaming responses.
+
+    Since streaming responses come in chunks, we need to buffer
+    until we can detect whether there's a CCR tool call.
+
+    Strategy:
+    1. Buffer chunks until we see a complete tool_use block
+    2. If it's a CCR call, switch to buffered mode
+    3. Handle CCR and then stream the continuation
+    """
+
+    chunks: list[bytes] = field(default_factory=list)
+    detected_ccr: bool = False
+    complete_response: dict[str, Any] | None = None
+
+    # Patterns to detect tool_use in stream. The byte-level match assumes the
+    # provider emits compact JSON (no space after the colon) in SSE data lines.
+    _tool_use_start: bytes = b'"type":"tool_use"'
+    _ccr_tool_pattern: bytes = f'"{CCR_TOOL_NAME}"'.encode()
+
+    def add_chunk(self, chunk: bytes) -> bool:
+        """Add a chunk and check for CCR tool calls.
+
+        Returns:
+            True if CCR tool call detected (should switch to buffered mode).
+        """
+        self.chunks.append(chunk)
+
+        # Quick check: does accumulated content contain CCR tool?
+        accumulated = b"".join(self.chunks)
+
+        if self._tool_use_start in accumulated and self._ccr_tool_pattern in accumulated:
+            self.detected_ccr = True
+            return True
+
+        return False
+
+    def get_accumulated(self) -> bytes:
+        """Get all accumulated chunks."""
+        return b"".join(self.chunks)
+
+    def clear(self) -> None:
+        """Clear the buffer."""
+        self.chunks.clear()
+        self.detected_ccr = False
+        self.complete_response = None
+
+
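# A quick illustration (not shipped code) of the detection above, assuming
# CCR_TOOL_NAME is the "headroom_retrieve" string the module docstring names.
# Detection spans chunk boundaries because add_chunk() re-joins the buffer.
buf = StreamingCCRBuffer()
assert not buf.add_chunk(b'data: {"type":"content_block_start","content_block":')
assert buf.add_chunk(b'{"type":"tool_use","id":"tu_1","name":"headroom_retrieve"}}\n\n')
assert buf.detected_ccr
buf.clear()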
+class StreamingCCRHandler:
+    """Handle CCR tool calls in streaming responses.
+
+    For streaming, we have two modes:
+    1. Pass-through: No CCR detected, stream chunks directly
+    2. Buffered: CCR detected, buffer response, handle, then stream result
+
+    The challenge is that we can't know whether there's a CCR call until
+    we've seen enough of the response, so we buffer initially, then decide.
+    """
+
+    def __init__(
+        self,
+        response_handler: CCRResponseHandler,
+        provider: str = "anthropic",
+    ) -> None:
+        self.response_handler = response_handler
+        self.provider = provider
+        self.buffer = StreamingCCRBuffer()
+
+    async def process_stream(
+        self,
+        stream_iterator: Any,  # AsyncIterator[bytes]
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]] | None,
+        api_call_fn: Callable[
+            [list[dict[str, Any]], list[dict[str, Any]] | None], Awaitable[dict[str, Any]]
+        ],
+    ) -> Any:  # AsyncGenerator[bytes, None]
+        """Process a streaming response, handling CCR if needed.
+
+        This is an async generator that yields chunks.
+        If CCR is detected, it buffers, handles, and re-streams.
+
+        Args:
+            stream_iterator: Async iterator of response chunks.
+            messages: The conversation messages.
+            tools: The tools list.
+            api_call_fn: Function to make API calls for continuation.
+
+        Yields:
+            Response chunks (possibly from continuation response).
+        """
+        # Phase 1: Initial detection
+        # Buffer chunks until we can determine if there's a CCR call
+        detection_complete = False
+
+        async for chunk in stream_iterator:
+            self.buffer.add_chunk(chunk)
+
+            # Check if we can determine CCR presence
+            # For Anthropic, tool_use blocks come after text content
+            # We need to see the stop_reason to know if there's a tool call
+            accumulated = self.buffer.get_accumulated()
+
+            # Look for stream end markers
+            if b'"stop_reason"' in accumulated:
+                detection_complete = True
+
+                if self.buffer.detected_ccr:
+                    # CCR detected - need to handle
+                    break
+                else:
+                    # No CCR - yield all buffered chunks
+                    for buffered_chunk in self.buffer.chunks:
+                        yield buffered_chunk
+                    self.buffer.clear()
+
+            # If we haven't detected anything yet and buffer is large,
+            # start yielding (response is probably just text)
+            elif len(accumulated) > 10000 and not self.buffer.detected_ccr:
+                for buffered_chunk in self.buffer.chunks:
+                    yield buffered_chunk
+                self.buffer.clear()
+
+        # Continue streaming rest of response
+        if not detection_complete and not self.buffer.detected_ccr:
+            async for chunk in stream_iterator:
+                if self.buffer.detected_ccr:
+                    self.buffer.add_chunk(chunk)
+                else:
+                    yield chunk
+
+        # Flush anything still buffered when the stream ended without a CCR
+        # call (e.g. trailing events buffered after the stop_reason flush)
+        if not self.buffer.detected_ccr and self.buffer.chunks:
+            for buffered_chunk in self.buffer.chunks:
+                yield buffered_chunk
+            self.buffer.clear()
+
+        # Phase 2: Handle CCR if detected
+        if self.buffer.detected_ccr:
+            logger.info("CCR: Detected tool call in stream, switching to buffered mode")
+
+            # Collect rest of stream
+            async for chunk in stream_iterator:
+                self.buffer.add_chunk(chunk)
+
+            # Parse the complete response
+            try:
+                # For SSE streams, we need to parse the accumulated data
+                complete_data = self._parse_sse_stream(self.buffer.get_accumulated())
+
+                # Handle CCR
+                final_response = await self.response_handler.handle_response(
+                    complete_data,
+                    messages,
+                    tools,
+                    api_call_fn,
+                    self.provider,
+                )
+
+                # Re-stream the final response
+                # Convert back to SSE format
+                async for chunk in self._response_to_sse(final_response):
+                    yield chunk
+
+            except Exception as e:
+                logger.error(f"CCR: Failed to handle streamed CCR: {e}")
+                # Fall back to yielding original buffered content
+                yield self.buffer.get_accumulated()
+
+    def _parse_sse_stream(self, data: bytes) -> dict[str, Any]:
+        """Parse SSE stream data into a response dict.
+
+        SSE format: data: {...}\n\n
+        """
+        # Accumulate all event data
+        events = []
+        for line in data.decode("utf-8", errors="replace").split("\n"):
+            if line.startswith("data: "):
+                event_data = line[6:]
+                if event_data.strip() and event_data.strip() != "[DONE]":
+                    try:
+                        events.append(json.loads(event_data))
+                    except json.JSONDecodeError:
+                        pass
+
+        # Reconstruct response from events
+        # This is provider-specific
+        if self.provider == "anthropic":
+            return self._reconstruct_anthropic_response(events)
+        else:
+            return self._reconstruct_openai_response(events)
+
+    def _reconstruct_anthropic_response(
+        self,
+        events: list[dict[str, Any]],
+    ) -> dict[str, Any]:
+        """Reconstruct Anthropic response from stream events."""
+        response: dict[str, Any] = {
+            "content": [],
+            "stop_reason": None,
+            "usage": {},
+        }
+
+        current_text = ""
+        current_tool: dict[str, Any] | None = None
+
+        for event in events:
+            event_type = event.get("type", "")
+
+            if event_type == "content_block_start":
+                block = event.get("content_block", {})
+                if block.get("type") == "text":
+                    current_text = block.get("text", "")
+                elif block.get("type") == "tool_use":
+                    current_tool = {
+                        "type": "tool_use",
+                        "id": block.get("id", ""),
+                        "name": block.get("name", ""),
+                        "input": {},
+                    }
+
+            elif event_type == "content_block_delta":
+                delta = event.get("delta", {})
+                if delta.get("type") == "text_delta":
+                    current_text += delta.get("text", "")
+                elif delta.get("type") == "input_json_delta":
+                    # Accumulate JSON for tool input: partial fragments arrive
+                    # across deltas, so parse only once the block stops
+                    if current_tool is not None:
+                        partial = delta.get("partial_json", "")
+                        current_tool["_partial_json"] = (
+                            current_tool.get("_partial_json", "") + partial
+                        )
+
+            elif event_type == "content_block_stop":
+                if current_text:
+                    response["content"].append(
+                        {
+                            "type": "text",
+                            "text": current_text,
+                        }
+                    )
+                    current_text = ""
+                if current_tool:
+                    # Parse accumulated JSON
+                    partial = current_tool.pop("_partial_json", "")
+                    if partial:
+                        try:
+                            current_tool["input"] = json.loads(partial)
+                        except json.JSONDecodeError:
+                            current_tool["input"] = {}
+                    response["content"].append(current_tool)
+                    current_tool = None
+
+            elif event_type == "message_delta":
+                delta = event.get("delta", {})
+                if "stop_reason" in delta:
+                    response["stop_reason"] = delta["stop_reason"]
+
+            elif event_type == "message_stop":
+                pass
+
+        return response
+
+    def _reconstruct_openai_response(
+        self,
+        events: list[dict[str, Any]],
+    ) -> dict[str, Any]:
+        """Reconstruct OpenAI response from stream events."""
+        message: dict[str, Any] = {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [],
+        }
+
+        tool_calls_map: dict[int, dict[str, Any]] = {}
+
+        for event in events:
+            choices = event.get("choices", [])
+            if not choices:
+                continue
+
+            delta = choices[0].get("delta", {})
+
+            if "content" in delta and delta["content"]:
+                message["content"] = (message.get("content") or "") + delta["content"]
+
+            if "tool_calls" in delta:
+                for tc_delta in delta["tool_calls"]:
+                    idx = tc_delta.get("index", 0)
+                    if idx not in tool_calls_map:
+                        tool_calls_map[idx] = {
+                            "id": "",
+                            "type": "function",
+                            "function": {"name": "", "arguments": ""},
+                        }
+
+                    tc = tool_calls_map[idx]
+                    if "id" in tc_delta:
+                        tc["id"] = tc_delta["id"]
+                    if "function" in tc_delta:
+                        fn = tc_delta["function"]
+                        if "name" in fn:
+                            tc["function"]["name"] = fn["name"]
+                        if "arguments" in fn:
+                            tc["function"]["arguments"] += fn["arguments"]
+
+        message["tool_calls"] = [tool_calls_map[i] for i in sorted(tool_calls_map.keys())]
+        if not message["tool_calls"]:
+            del message["tool_calls"]
+        if not message["content"]:
+            message["content"] = None
+
+        return {
+            "choices": [{"message": message, "finish_reason": "stop"}],
+        }
+
+    async def _response_to_sse(
+        self,
+        response: dict[str, Any],
+    ) -> Any:  # AsyncGenerator[bytes, None]
+        """Convert a response back to SSE format for streaming.
+
+        This is a simplified version - in practice you might want
+        to chunk the response more granularly.
+        """
+        if self.provider == "anthropic":
+            # Anthropic SSE format
+            yield b"event: message_start\n"
+            yield f"data: {json.dumps({'type': 'message_start', 'message': response})}\n\n".encode()
+            yield b"event: message_stop\n"
+            yield b'data: {"type": "message_stop"}\n\n'
+        else:
+            # OpenAI SSE format
+            yield f"data: {json.dumps(response)}\n\n".encode()
+            yield b"data: [DONE]\n\n"