headroom-ai 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
@@ -0,0 +1,772 @@
"""Response handling for CCR (Compress-Cache-Retrieve).

This module provides response interception and CCR tool call handling.
When the LLM calls headroom_retrieve, this handler:
1. Detects the tool call in the response
2. Retrieves content from the compression store
3. Continues the conversation with the tool result
4. Returns the final response to the client

This solves the critical gap where the proxy injects the tool but
can't handle the LLM's tool calls.
"""

from __future__ import annotations

import json
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any

from ..cache.compression_store import get_compression_store
from .tool_injection import CCR_TOOL_NAME, parse_tool_call

logger = logging.getLogger(__name__)

@dataclass
class CCRToolCall:
    """Represents a detected CCR tool call."""

    tool_call_id: str
    hash_key: str
    query: str | None = None


@dataclass
class CCRToolResult:
    """Result of handling a CCR tool call."""

    tool_call_id: str
    content: str
    success: bool
    items_retrieved: int = 0
    was_search: bool = False


@dataclass
class ResponseHandlerConfig:
    """Configuration for CCR response handling."""

    # Whether to handle CCR tool calls automatically
    enabled: bool = True

    # Maximum number of CCR retrieval rounds (prevents infinite loops)
    max_retrieval_rounds: int = 3

    # Whether to strip CCR tool calls from final response
    strip_ccr_from_response: bool = True

    # Timeout for continuation requests (ms)
    continuation_timeout_ms: int = 120000

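For reference, a minimal sketch of constructing a non-default configuration (the values shown are illustrative, not recommendations):

    config = ResponseHandlerConfig(
        max_retrieval_rounds=1,  # allow only a single retrieval round
        continuation_timeout_ms=30_000,
    )
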
class CCRResponseHandler:
    """Handles CCR tool calls in LLM responses.

    This handler intercepts responses, detects CCR tool calls,
    retrieves content, and continues the conversation until
    the LLM produces a response without CCR tool calls.

    Example flow:
    1. LLM response contains: tool_use(headroom_retrieve, hash=abc123)
    2. Handler detects this, retrieves original content
    3. Handler makes another API call with the tool result
    4. LLM responds with actual content (no CCR tool call)
    5. Handler returns this final response

    Usage:
        handler = CCRResponseHandler(config)

        # Check if response needs handling
        if handler.has_ccr_tool_calls(response_json):
            # Handle the tool calls
            final_response = await handler.handle_response(
                response_json,
                messages,
                tools,
                api_call_fn,
                provider="anthropic"
            )
        else:
            final_response = response_json
    """

    def __init__(self, config: ResponseHandlerConfig | None = None):
        self.config = config or ResponseHandlerConfig()
        self._retrieval_count = 0

    def has_ccr_tool_calls(
        self,
        response: dict[str, Any],
        provider: str = "anthropic",
    ) -> bool:
        """Check if response contains CCR tool calls.

        Args:
            response: The API response JSON.
            provider: The provider type.

        Returns:
            True if response contains headroom_retrieve tool calls.
        """
        tool_calls = self._extract_tool_calls(response, provider)
        return any(
            tc.get("name") == CCR_TOOL_NAME or tc.get("function", {}).get("name") == CCR_TOOL_NAME
            for tc in tool_calls
        )

    def _extract_tool_calls(
        self,
        response: dict[str, Any],
        provider: str,
    ) -> list[dict[str, Any]]:
        """Extract tool calls from response based on provider format."""
        if provider == "anthropic":
            # Anthropic format: content blocks with type=tool_use
            content = response.get("content", [])
            if isinstance(content, list):
                return [block for block in content if block.get("type") == "tool_use"]
            return []

        elif provider == "openai":
            # OpenAI format: message.tool_calls array
            message = response.get("choices", [{}])[0].get("message", {})
            tool_calls = message.get("tool_calls", [])
            return list(tool_calls) if tool_calls else []

        return []

    def _parse_ccr_tool_calls(
        self,
        response: dict[str, Any],
        provider: str,
    ) -> tuple[list[CCRToolCall], list[dict[str, Any]]]:
        """Parse CCR tool calls from response, separate from other tool calls.

        Returns:
            Tuple of (ccr_tool_calls, other_tool_calls)
        """
        all_tool_calls = self._extract_tool_calls(response, provider)

        ccr_calls = []
        other_calls = []

        for tc in all_tool_calls:
            hash_key, query = parse_tool_call(tc, provider)

            if hash_key is not None:
                # This is a CCR tool call
                tool_call_id = tc.get("id", "")
                ccr_calls.append(
                    CCRToolCall(
                        tool_call_id=tool_call_id,
                        hash_key=hash_key,
                        query=query,
                    )
                )
            else:
                # Not a CCR tool call
                other_calls.append(tc)

        return ccr_calls, other_calls

    def _execute_retrieval(self, ccr_call: CCRToolCall) -> CCRToolResult:
        """Execute a CCR retrieval.

        Args:
            ccr_call: The CCR tool call to execute.

        Returns:
            CCRToolResult with the retrieved content.
        """
        store = get_compression_store()

        try:
            if ccr_call.query:
                # Search within compressed content
                results = store.search(ccr_call.hash_key, ccr_call.query)
                content = json.dumps(
                    {
                        "hash": ccr_call.hash_key,
                        "query": ccr_call.query,
                        "results": results,
                        "count": len(results),
                    },
                    indent=2,
                )
                return CCRToolResult(
                    tool_call_id=ccr_call.tool_call_id,
                    content=content,
                    success=True,
                    items_retrieved=len(results),
                    was_search=True,
                )
            else:
                # Full retrieval
                entry = store.retrieve(ccr_call.hash_key)
                if entry:
                    content = json.dumps(
                        {
                            "hash": ccr_call.hash_key,
                            "original_content": entry.original_content,
                            "original_item_count": entry.original_item_count,
                        },
                        indent=2,
                    )
                    return CCRToolResult(
                        tool_call_id=ccr_call.tool_call_id,
                        content=content,
                        success=True,
                        items_retrieved=entry.original_item_count,
                        was_search=False,
                    )
                else:
                    content = json.dumps(
                        {
                            "error": "Entry not found or expired (TTL: 5 minutes)",
                            "hash": ccr_call.hash_key,
                        },
                        indent=2,
                    )
                    return CCRToolResult(
                        tool_call_id=ccr_call.tool_call_id,
                        content=content,
                        success=False,
                    )

        except Exception as e:
            logger.error(f"CCR retrieval failed for {ccr_call.hash_key}: {e}")
            content = json.dumps(
                {
                    "error": f"Retrieval failed: {str(e)}",
                    "hash": ccr_call.hash_key,
                },
                indent=2,
            )
            return CCRToolResult(
                tool_call_id=ccr_call.tool_call_id,
                content=content,
                success=False,
            )

    def _create_tool_result_message(
        self,
        results: list[CCRToolResult],
        provider: str,
    ) -> dict[str, Any]:
        """Create a tool result message from CCR results.

        Args:
            results: List of CCR tool results.
            provider: The provider type.

        Returns:
            Message dict in the appropriate format.
        """
        if provider == "anthropic":
            # Anthropic: user message with tool_result content blocks
            content_blocks = []
            for result in results:
                content_blocks.append(
                    {
                        "type": "tool_result",
                        "tool_use_id": result.tool_call_id,
                        "content": result.content,
                    }
                )
            return {
                "role": "user",
                "content": content_blocks,
            }

        elif provider == "openai":
            # OpenAI expects one message per tool result, so wrap the list
            # in a sentinel key that the caller unpacks into separate messages.
            return {
                "_openai_tool_results": [
                    {
                        "role": "tool",
                        "tool_call_id": result.tool_call_id,
                        "content": result.content,
                    }
                    for result in results
                ]
            }

        else:
            # Generic format
            return {
                "role": "tool",
                "content": json.dumps(
                    [{"tool_call_id": r.tool_call_id, "result": r.content} for r in results]
                ),
            }

    def _extract_assistant_message(
        self,
        response: dict[str, Any],
        provider: str,
    ) -> dict[str, Any]:
        """Extract the assistant message from an API response.

        Args:
            response: The API response.
            provider: The provider type.

        Returns:
            The assistant message dict.
        """
        if provider == "anthropic":
            return {
                "role": "assistant",
                "content": response.get("content", []),
            }
        elif provider == "openai":
            message = response.get("choices", [{}])[0].get("message", {})
            return {
                "role": "assistant",
                "content": message.get("content"),
                "tool_calls": message.get("tool_calls"),
            }
        else:
            return {
                "role": "assistant",
                "content": response.get("content", ""),
            }

    async def handle_response(
        self,
        response: dict[str, Any],
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        api_call_fn: Callable[
            [list[dict[str, Any]], list[dict[str, Any]] | None], Awaitable[dict[str, Any]]
        ],
        provider: str = "anthropic",
    ) -> dict[str, Any]:
        """Handle CCR tool calls in a response.

        This method:
        1. Detects CCR tool calls
        2. Executes retrievals
        3. Continues the conversation with tool results
        4. Repeats until no CCR tool calls remain

        Args:
            response: The initial API response.
            messages: The conversation messages.
            tools: The tools list (should include the CCR tool).
            api_call_fn: Async function to make API calls.
                Signature: (messages, tools) -> response
            provider: The provider type.

        Returns:
            The final response (with no CCR tool calls).
        """
        if not self.config.enabled:
            return response

        current_response = response
        current_messages = list(messages)  # Copy to avoid mutation
        rounds = 0

        while rounds < self.config.max_retrieval_rounds:
            # Check for CCR tool calls
            ccr_calls, other_calls = self._parse_ccr_tool_calls(current_response, provider)

            if not ccr_calls:
                # No CCR tool calls, we're done
                break

            rounds += 1
            self._retrieval_count += len(ccr_calls)

            logger.info(f"CCR: Handling {len(ccr_calls)} retrieval(s) in round {rounds}")

            # Execute all CCR retrievals
            results = [self._execute_retrieval(call) for call in ccr_calls]

            # Log retrieval stats
            total_items = sum(r.items_retrieved for r in results)
            searches = sum(1 for r in results if r.was_search)
            logger.debug(
                f"CCR: Retrieved {total_items} items "
                f"({searches} searches, {len(results) - searches} full)"
            )

            # Build continuation messages:
            # add the assistant message (the response that had tool calls)
            assistant_msg = self._extract_assistant_message(current_response, provider)
            current_messages.append(assistant_msg)

            # Add tool results
            tool_result_msg = self._create_tool_result_message(results, provider)

            if provider == "openai" and "_openai_tool_results" in tool_result_msg:
                # OpenAI uses multiple messages for tool results
                current_messages.extend(tool_result_msg["_openai_tool_results"])
            else:
                current_messages.append(tool_result_msg)

            # Make continuation API call
            try:
                current_response = await api_call_fn(current_messages, tools)
            except Exception as e:
                logger.error(f"CCR: Continuation API call failed: {e}")
                # Return the response we had (with unhandled CCR calls).
                # The client will see the tool_use and might handle it differently.
                break

        if rounds >= self.config.max_retrieval_rounds:
            logger.warning(
                f"CCR: Hit max retrieval rounds ({self.config.max_retrieval_rounds}), "
                f"returning response with possibly unhandled CCR calls"
            )

        return current_response

    def get_stats(self) -> dict[str, Any]:
        """Get handler statistics."""
        return {
            "total_retrievals": self._retrieval_count,
            "config": {
                "enabled": self.config.enabled,
                "max_rounds": self.config.max_retrieval_rounds,
            },
        }

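Non-streaming callers drive this loop through handle_response. A minimal sketch of the wiring, assuming some post_to_provider coroutine (a hypothetical stand-in for the real upstream call; any async callable matching the (messages, tools) -> response-dict signature fits api_call_fn):

    async def post_to_provider(messages, tools):
        ...  # hypothetical: forward to the upstream API, return its JSON as a dict

    async def complete_with_ccr(response_json, messages, tools):
        handler = CCRResponseHandler()
        if handler.has_ccr_tool_calls(response_json, provider="anthropic"):
            return await handler.handle_response(
                response_json, messages, tools, post_to_provider, provider="anthropic"
            )
        return response_json
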
@dataclass
class StreamingCCRBuffer:
    """Buffer for detecting CCR tool calls in streaming responses.

    Since streaming responses come in chunks, we need to buffer
    until we can detect whether there's a CCR tool call.

    Strategy:
    1. Buffer chunks until we see a complete tool_use block
    2. If it's a CCR call, switch to buffered mode
    3. Handle CCR and then stream the continuation
    """

    chunks: list[bytes] = field(default_factory=list)
    detected_ccr: bool = False
    complete_response: dict[str, Any] | None = None

    # Patterns to detect tool_use in stream
    _tool_use_start: bytes = b'"type":"tool_use"'
    _ccr_tool_pattern: bytes = f'"{CCR_TOOL_NAME}"'.encode()

    def add_chunk(self, chunk: bytes) -> bool:
        """Add a chunk and check for CCR tool calls.

        Returns:
            True if CCR tool call detected (should switch to buffered mode).
        """
        self.chunks.append(chunk)

        # Quick check: does accumulated content contain the CCR tool?
        accumulated = b"".join(self.chunks)

        if self._tool_use_start in accumulated and self._ccr_tool_pattern in accumulated:
            self.detected_ccr = True
            return True

        return False

    def get_accumulated(self) -> bytes:
        """Get all accumulated chunks."""
        return b"".join(self.chunks)

    def clear(self) -> None:
        """Clear the buffer."""
        self.chunks.clear()
        self.detected_ccr = False
        self.complete_response = None

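To make the detection concrete, a small sketch (the chunk bytes are illustrative, and CCR_TOOL_NAME is assumed to be "headroom_retrieve", as the module docstring suggests):

    buf = StreamingCCRBuffer()
    buf.add_chunk(b'data: {"type":"content_block_start",')  # returns False: no match yet
    hit = buf.add_chunk(b'"content_block":{"type":"tool_use","name":"headroom_retrieve"}}')
    # hit is True: both the tool_use marker and the quoted tool name are now present
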
class StreamingCCRHandler:
    """Handle CCR tool calls in streaming responses.

    For streaming, we have two modes:
    1. Pass-through: No CCR detected, stream chunks directly
    2. Buffered: CCR detected, buffer response, handle, then stream result

    The challenge is we can't know if there's a CCR call until we see
    enough of the response. So we buffer initially, then decide.
    """

    def __init__(
        self,
        response_handler: CCRResponseHandler,
        provider: str = "anthropic",
    ) -> None:
        self.response_handler = response_handler
        self.provider = provider
        self.buffer = StreamingCCRBuffer()

    async def process_stream(
        self,
        stream_iterator: Any,  # AsyncIterator[bytes]
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        api_call_fn: Callable[
            [list[dict[str, Any]], list[dict[str, Any]] | None], Awaitable[dict[str, Any]]
        ],
    ) -> Any:  # AsyncGenerator[bytes, None]
        """Process a streaming response, handling CCR if needed.

        This is an async generator that yields chunks.
        If CCR is detected, it buffers, handles, and re-streams.

        Args:
            stream_iterator: Async iterator of response chunks.
            messages: The conversation messages.
            tools: The tools list.
            api_call_fn: Function to make API calls for continuation.

        Yields:
            Response chunks (possibly from continuation response).
        """
        # Phase 1: Initial detection.
        # Buffer chunks until we can determine if there's a CCR call.
        detection_complete = False

        async for chunk in stream_iterator:
            self.buffer.add_chunk(chunk)

            # Check if we can determine CCR presence.
            # For Anthropic, tool_use blocks come after text content;
            # we need to see the stop_reason to know if there's a tool call.
            accumulated = self.buffer.get_accumulated()

            # Look for stream end markers
            if b'"stop_reason"' in accumulated:
                detection_complete = True

                if self.buffer.detected_ccr:
                    # CCR detected - need to handle
                    break
                else:
                    # No CCR - yield all buffered chunks
                    for buffered_chunk in self.buffer.chunks:
                        yield buffered_chunk
                    self.buffer.clear()

            # If we haven't detected anything yet and the buffer is large,
            # start yielding (the response is probably just text)
            elif len(accumulated) > 10000 and not self.buffer.detected_ccr:
                for buffered_chunk in self.buffer.chunks:
                    yield buffered_chunk
                self.buffer.clear()

        # Continue streaming the rest of the response
        if not detection_complete and not self.buffer.detected_ccr:
            async for chunk in stream_iterator:
                if self.buffer.detected_ccr:
                    self.buffer.add_chunk(chunk)
                else:
                    yield chunk

        # Phase 2: Handle CCR if detected
        if self.buffer.detected_ccr:
            logger.info("CCR: Detected tool call in stream, switching to buffered mode")

            # Collect the rest of the stream
            async for chunk in stream_iterator:
                self.buffer.add_chunk(chunk)

            # Parse the complete response
            try:
                # For SSE streams, we need to parse the accumulated data
                complete_data = self._parse_sse_stream(self.buffer.get_accumulated())

                # Handle CCR
                final_response = await self.response_handler.handle_response(
                    complete_data,
                    messages,
                    tools,
                    api_call_fn,
                    self.provider,
                )

                # Re-stream the final response,
                # converted back to SSE format
                async for chunk in self._response_to_sse(final_response):
                    yield chunk

            except Exception as e:
                logger.error(f"CCR: Failed to handle streamed CCR: {e}")
                # Fall back to yielding original buffered content
                yield self.buffer.get_accumulated()

    def _parse_sse_stream(self, data: bytes) -> dict[str, Any]:
        """Parse SSE stream data into a response dict.

        SSE format: data: {...}\n\n
        """
        # Accumulate all event data
        events = []
        for line in data.decode("utf-8", errors="replace").split("\n"):
            if line.startswith("data: "):
                event_data = line[6:]
                if event_data.strip() and event_data.strip() != "[DONE]":
                    try:
                        events.append(json.loads(event_data))
                    except json.JSONDecodeError:
                        pass

        # Reconstruct response from events.
        # This is provider-specific.
        if self.provider == "anthropic":
            return self._reconstruct_anthropic_response(events)
        else:
            return self._reconstruct_openai_response(events)

    def _reconstruct_anthropic_response(
        self,
        events: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Reconstruct Anthropic response from stream events."""
        response: dict[str, Any] = {
            "content": [],
            "stop_reason": None,
            "usage": {},
        }

        current_text = ""
        current_tool: dict[str, Any] | None = None

        for event in events:
            event_type = event.get("type", "")

            if event_type == "content_block_start":
                block = event.get("content_block", {})
                if block.get("type") == "text":
                    current_text = block.get("text", "")
                elif block.get("type") == "tool_use":
                    current_tool = {
                        "type": "tool_use",
                        "id": block.get("id", ""),
                        "name": block.get("name", ""),
                        "input": {},
                    }

            elif event_type == "content_block_delta":
                delta = event.get("delta", {})
                if delta.get("type") == "text_delta":
                    current_text += delta.get("text", "")
                elif delta.get("type") == "input_json_delta":
                    # Accumulate JSON for tool input.
                    # This is tricky - partial JSON needs accumulation.
                    # For simplicity, we'll try to parse when complete.
                    if current_tool is not None:
                        partial = delta.get("partial_json", "")
                        current_tool["_partial_json"] = (
                            current_tool.get("_partial_json", "") + partial
                        )

            elif event_type == "content_block_stop":
                if current_text:
                    response["content"].append(
                        {
                            "type": "text",
                            "text": current_text,
                        }
                    )
                    current_text = ""
                if current_tool:
                    # Parse accumulated JSON
                    partial = current_tool.pop("_partial_json", "")
                    if partial:
                        try:
                            current_tool["input"] = json.loads(partial)
                        except json.JSONDecodeError:
                            current_tool["input"] = {}
                    response["content"].append(current_tool)
                    current_tool = None

            elif event_type == "message_delta":
                delta = event.get("delta", {})
                if "stop_reason" in delta:
                    response["stop_reason"] = delta["stop_reason"]

            elif event_type == "message_stop":
                pass

        return response

    def _reconstruct_openai_response(
        self,
        events: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Reconstruct OpenAI response from stream events."""
        message: dict[str, Any] = {
            "role": "assistant",
            "content": "",
            "tool_calls": [],
        }

        tool_calls_map: dict[int, dict[str, Any]] = {}

        for event in events:
            choices = event.get("choices", [])
            if not choices:
                continue

            delta = choices[0].get("delta", {})

            if "content" in delta and delta["content"]:
                message["content"] = (message.get("content") or "") + delta["content"]

            if "tool_calls" in delta:
                for tc_delta in delta["tool_calls"]:
                    idx = tc_delta.get("index", 0)
                    if idx not in tool_calls_map:
                        tool_calls_map[idx] = {
                            "id": "",
                            "type": "function",
                            "function": {"name": "", "arguments": ""},
                        }

                    tc = tool_calls_map[idx]
                    if "id" in tc_delta:
                        tc["id"] = tc_delta["id"]
                    if "function" in tc_delta:
                        fn = tc_delta["function"]
                        if "name" in fn:
                            tc["function"]["name"] = fn["name"]
                        if "arguments" in fn:
                            tc["function"]["arguments"] += fn["arguments"]

        message["tool_calls"] = [tool_calls_map[i] for i in sorted(tool_calls_map.keys())]
        if not message["tool_calls"]:
            del message["tool_calls"]
        if not message["content"]:
            message["content"] = None

        return {
            "choices": [{"message": message, "finish_reason": "stop"}],
        }

    async def _response_to_sse(
        self,
        response: dict[str, Any],
    ) -> Any:  # AsyncGenerator[bytes, None]
        """Convert a response back to SSE format for streaming.

        This is a simplified version - in practice you might want
        to chunk the response more granularly.
        """
        if self.provider == "anthropic":
            # Anthropic SSE format
            yield b"event: message_start\n"
            yield f"data: {json.dumps({'type': 'message_start', 'message': response})}\n\n".encode()
            yield b"event: message_stop\n"
            yield b'data: {"type": "message_stop"}\n\n'
        else:
            # OpenAI SSE format
            yield f"data: {json.dumps(response)}\n\n".encode()
            yield b"data: [DONE]\n\n"