sqlsaber 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/anthropic.py +283 -176
- sqlsaber/agents/base.py +11 -11
- sqlsaber/agents/streaming.py +3 -3
- sqlsaber/cli/auth.py +142 -0
- sqlsaber/cli/commands.py +9 -4
- sqlsaber/cli/completers.py +3 -5
- sqlsaber/cli/database.py +9 -10
- sqlsaber/cli/display.py +5 -7
- sqlsaber/cli/interactive.py +2 -3
- sqlsaber/cli/memory.py +7 -9
- sqlsaber/cli/models.py +1 -2
- sqlsaber/cli/streaming.py +5 -31
- sqlsaber/clients/__init__.py +6 -0
- sqlsaber/clients/anthropic.py +285 -0
- sqlsaber/clients/base.py +31 -0
- sqlsaber/clients/exceptions.py +117 -0
- sqlsaber/clients/models.py +282 -0
- sqlsaber/clients/streaming.py +257 -0
- sqlsaber/config/api_keys.py +2 -3
- sqlsaber/config/auth.py +86 -0
- sqlsaber/config/database.py +20 -20
- sqlsaber/config/oauth_flow.py +274 -0
- sqlsaber/config/oauth_tokens.py +175 -0
- sqlsaber/config/settings.py +34 -23
- sqlsaber/database/connection.py +9 -9
- sqlsaber/database/schema.py +25 -25
- sqlsaber/mcp/mcp.py +3 -4
- sqlsaber/memory/manager.py +3 -5
- sqlsaber/memory/storage.py +7 -8
- sqlsaber/models/events.py +4 -4
- sqlsaber/models/types.py +10 -10
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.0.dist-info}/METADATA +1 -1
- sqlsaber-0.8.0.dist-info/RECORD +46 -0
- sqlsaber-0.7.0.dist-info/RECORD +0 -36
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/anthropic.py
CHANGED
|
@@ -1,25 +1,34 @@
|
|
|
1
|
-
"""Anthropic-specific SQL agent implementation."""
|
|
1
|
+
"""Anthropic-specific SQL agent implementation using the custom client."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import json
|
|
5
|
-
from typing import Any, AsyncIterator
|
|
6
|
-
|
|
7
|
-
from anthropic import AsyncAnthropic
|
|
5
|
+
from typing import Any, AsyncIterator
|
|
8
6
|
|
|
9
7
|
from sqlsaber.agents.base import BaseSQLAgent
|
|
10
8
|
from sqlsaber.agents.streaming import (
|
|
11
|
-
StreamingResponse,
|
|
12
9
|
build_tool_result_block,
|
|
13
10
|
)
|
|
11
|
+
from sqlsaber.clients import AnthropicClient
|
|
12
|
+
from sqlsaber.clients.models import (
|
|
13
|
+
ContentBlock,
|
|
14
|
+
ContentType,
|
|
15
|
+
CreateMessageRequest,
|
|
16
|
+
Message,
|
|
17
|
+
MessageRole,
|
|
18
|
+
ToolDefinition,
|
|
19
|
+
)
|
|
14
20
|
from sqlsaber.config.settings import Config
|
|
15
21
|
from sqlsaber.database.connection import BaseDatabaseConnection
|
|
16
22
|
from sqlsaber.memory.manager import MemoryManager
|
|
17
23
|
from sqlsaber.models.events import StreamEvent
|
|
18
|
-
from sqlsaber.models.types import ToolDefinition
|
|
19
24
|
|
|
20
25
|
|
|
21
26
|
class AnthropicSQLAgent(BaseSQLAgent):
|
|
22
|
-
"""SQL Agent using Anthropic
|
|
27
|
+
"""SQL Agent using the custom Anthropic client."""
|
|
28
|
+
|
|
29
|
+
# Constants
|
|
30
|
+
MAX_TOKENS = 4096
|
|
31
|
+
DEFAULT_SQL_LIMIT = 100
|
|
23
32
|
|
|
24
33
|
def __init__(
|
|
25
34
|
self, db_connection: BaseDatabaseConnection, database_name: str | None = None
|
|
@@ -27,9 +36,12 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
27
36
|
super().__init__(db_connection)
|
|
28
37
|
|
|
29
38
|
config = Config()
|
|
30
|
-
config.validate() # This will raise ValueError if
|
|
39
|
+
config.validate() # This will raise ValueError if credentials are missing
|
|
31
40
|
|
|
32
|
-
|
|
41
|
+
if config.oauth_token:
|
|
42
|
+
self.client = AnthropicClient(oauth_token=config.oauth_token)
|
|
43
|
+
else:
|
|
44
|
+
self.client = AnthropicClient(api_key=config.api_key)
|
|
33
45
|
self.model = config.model_name.replace("anthropic:", "")
|
|
34
46
|
|
|
35
47
|
self.database_name = database_name
|
|
@@ -39,21 +51,21 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
39
51
|
self._last_results = None
|
|
40
52
|
self._last_query = None
|
|
41
53
|
|
|
42
|
-
# Define tools in
|
|
43
|
-
self.tools:
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
54
|
+
# Define tools in the new format
|
|
55
|
+
self.tools: list[ToolDefinition] = [
|
|
56
|
+
ToolDefinition(
|
|
57
|
+
name="list_tables",
|
|
58
|
+
description="Get a list of all tables in the database with row counts. Use this first to discover available tables.",
|
|
59
|
+
input_schema={
|
|
48
60
|
"type": "object",
|
|
49
61
|
"properties": {},
|
|
50
62
|
"required": [],
|
|
51
63
|
},
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
64
|
+
),
|
|
65
|
+
ToolDefinition(
|
|
66
|
+
name="introspect_schema",
|
|
67
|
+
description="Introspect database schema to understand table structures.",
|
|
68
|
+
input_schema={
|
|
57
69
|
"type": "object",
|
|
58
70
|
"properties": {
|
|
59
71
|
"table_pattern": {
|
|
@@ -63,11 +75,11 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
63
75
|
},
|
|
64
76
|
"required": [],
|
|
65
77
|
},
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
78
|
+
),
|
|
79
|
+
ToolDefinition(
|
|
80
|
+
name="execute_sql",
|
|
81
|
+
description="Execute a SQL query against the database.",
|
|
82
|
+
input_schema={
|
|
71
83
|
"type": "object",
|
|
72
84
|
"properties": {
|
|
73
85
|
"query": {
|
|
@@ -76,17 +88,17 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
76
88
|
},
|
|
77
89
|
"limit": {
|
|
78
90
|
"type": "integer",
|
|
79
|
-
"description": "Maximum number of rows to return (default:
|
|
80
|
-
"default":
|
|
91
|
+
"description": f"Maximum number of rows to return (default: {AnthropicSQLAgent.DEFAULT_SQL_LIMIT})",
|
|
92
|
+
"default": AnthropicSQLAgent.DEFAULT_SQL_LIMIT,
|
|
81
93
|
},
|
|
82
94
|
},
|
|
83
95
|
"required": ["query"],
|
|
84
96
|
},
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
97
|
+
),
|
|
98
|
+
ToolDefinition(
|
|
99
|
+
name="plot_data",
|
|
100
|
+
description="Create a plot of query results.",
|
|
101
|
+
input_schema={
|
|
90
102
|
"type": "object",
|
|
91
103
|
"properties": {
|
|
92
104
|
"y_values": {
|
|
@@ -120,7 +132,7 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
120
132
|
},
|
|
121
133
|
"required": ["y_values"],
|
|
122
134
|
},
|
|
123
|
-
|
|
135
|
+
),
|
|
124
136
|
]
|
|
125
137
|
|
|
126
138
|
# Build system prompt with memories if available
|
|
@@ -128,8 +140,24 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
128
140
|
|
|
129
141
|
def _build_system_prompt(self) -> str:
|
|
130
142
|
"""Build system prompt with optional memory context."""
|
|
143
|
+
# For OAuth authentication, start with Claude Code identity
|
|
144
|
+
# Check if we're using OAuth by looking at the client
|
|
145
|
+
is_oauth = (
|
|
146
|
+
hasattr(self, "client")
|
|
147
|
+
and hasattr(self.client, "use_oauth")
|
|
148
|
+
and self.client.use_oauth
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
if is_oauth:
|
|
152
|
+
# For OAuth, keep system prompt minimal - just Claude Code identity
|
|
153
|
+
return "You are Claude Code, Anthropic's official CLI for Claude."
|
|
154
|
+
else:
|
|
155
|
+
return self._get_sql_assistant_instructions()
|
|
156
|
+
|
|
157
|
+
def _get_sql_assistant_instructions(self) -> str:
|
|
158
|
+
"""Get the detailed SQL assistant instructions."""
|
|
131
159
|
db_type = self._get_database_type_name()
|
|
132
|
-
|
|
160
|
+
instructions = f"""You are also a helpful SQL assistant that helps users query their {db_type} database.
|
|
133
161
|
|
|
134
162
|
Your responsibilities:
|
|
135
163
|
1. Understand user's natural language requests, think and convert them to SQL
|
|
@@ -161,9 +189,9 @@ Guidelines:
|
|
|
161
189
|
self.database_name
|
|
162
190
|
)
|
|
163
191
|
if memory_context.strip():
|
|
164
|
-
|
|
192
|
+
instructions += memory_context
|
|
165
193
|
|
|
166
|
-
return
|
|
194
|
+
return instructions
|
|
167
195
|
|
|
168
196
|
def add_memory(self, content: str) -> str | None:
|
|
169
197
|
"""Add a memory for the current database."""
|
|
@@ -197,83 +225,129 @@ Guidelines:
|
|
|
197
225
|
return result
|
|
198
226
|
|
|
199
227
|
async def process_tool_call(
|
|
200
|
-
self, tool_name: str, tool_input:
|
|
228
|
+
self, tool_name: str, tool_input: dict[str, Any]
|
|
201
229
|
) -> str:
|
|
202
230
|
"""Process a tool call and return the result."""
|
|
203
231
|
# Use parent implementation for core tools
|
|
204
232
|
return await super().process_tool_call(tool_name, tool_input)
|
|
205
233
|
|
|
206
|
-
|
|
207
|
-
self,
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
234
|
+
def _convert_user_message_to_message(
|
|
235
|
+
self, msg_content: str | list[dict[str, Any]]
|
|
236
|
+
) -> Message:
|
|
237
|
+
"""Convert user message content to Message object."""
|
|
238
|
+
if isinstance(msg_content, str):
|
|
239
|
+
return Message(MessageRole.USER, msg_content)
|
|
240
|
+
|
|
241
|
+
# Handle tool results format
|
|
242
|
+
tool_result_blocks = []
|
|
243
|
+
if isinstance(msg_content, list):
|
|
244
|
+
for item in msg_content:
|
|
245
|
+
if isinstance(item, dict) and item.get("type") == "tool_result":
|
|
246
|
+
tool_result_blocks.append(
|
|
247
|
+
ContentBlock(ContentType.TOOL_RESULT, item)
|
|
248
|
+
)
|
|
218
249
|
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
250
|
+
if tool_result_blocks:
|
|
251
|
+
return Message(MessageRole.USER, tool_result_blocks)
|
|
252
|
+
|
|
253
|
+
# Fallback to string representation
|
|
254
|
+
return Message(MessageRole.USER, str(msg_content))
|
|
255
|
+
|
|
256
|
+
def _convert_assistant_message_to_message(
|
|
257
|
+
self, msg_content: str | list[dict[str, Any]]
|
|
258
|
+
) -> Message:
|
|
259
|
+
"""Convert assistant message content to Message object."""
|
|
260
|
+
if isinstance(msg_content, str):
|
|
261
|
+
return Message(MessageRole.ASSISTANT, msg_content)
|
|
262
|
+
|
|
263
|
+
if isinstance(msg_content, list):
|
|
264
|
+
content_blocks = []
|
|
265
|
+
for block in msg_content:
|
|
266
|
+
if isinstance(block, dict):
|
|
267
|
+
if block.get("type") == "text":
|
|
268
|
+
text_content = block.get("text", "")
|
|
269
|
+
if text_content: # Only add non-empty text blocks
|
|
270
|
+
content_blocks.append(
|
|
271
|
+
ContentBlock(ContentType.TEXT, text_content)
|
|
272
|
+
)
|
|
273
|
+
elif block.get("type") == "tool_use":
|
|
274
|
+
content_blocks.append(
|
|
275
|
+
ContentBlock(
|
|
276
|
+
ContentType.TOOL_USE,
|
|
277
|
+
{
|
|
278
|
+
"id": block["id"],
|
|
279
|
+
"name": block["name"],
|
|
280
|
+
"input": block["input"],
|
|
281
|
+
},
|
|
282
|
+
)
|
|
225
283
|
)
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
284
|
+
if content_blocks:
|
|
285
|
+
return Message(MessageRole.ASSISTANT, content_blocks)
|
|
286
|
+
|
|
287
|
+
# Fallback to string representation
|
|
288
|
+
return Message(MessageRole.ASSISTANT, str(msg_content))
|
|
289
|
+
|
|
290
|
+
def _convert_history_to_messages(self) -> list[Message]:
|
|
291
|
+
"""Convert conversation history to Message objects."""
|
|
292
|
+
messages = []
|
|
293
|
+
for msg in self.conversation_history:
|
|
294
|
+
if msg["role"] == "user":
|
|
295
|
+
messages.append(self._convert_user_message_to_message(msg["content"]))
|
|
296
|
+
elif msg["role"] == "assistant":
|
|
297
|
+
messages.append(
|
|
298
|
+
self._convert_assistant_message_to_message(msg["content"])
|
|
299
|
+
)
|
|
300
|
+
return messages
|
|
301
|
+
|
|
302
|
+
def _convert_tool_results_to_message(
|
|
303
|
+
self, tool_results: list[dict[str, Any]]
|
|
304
|
+
) -> Message:
|
|
305
|
+
"""Convert tool results to a user Message object."""
|
|
306
|
+
tool_result_blocks = []
|
|
307
|
+
for tool_result in tool_results:
|
|
308
|
+
tool_result_blocks.append(
|
|
309
|
+
ContentBlock(ContentType.TOOL_RESULT, tool_result)
|
|
310
|
+
)
|
|
311
|
+
return Message(MessageRole.USER, tool_result_blocks)
|
|
312
|
+
|
|
313
|
+
def _convert_response_content_to_message(
|
|
314
|
+
self, content: list[dict[str, Any]]
|
|
315
|
+
) -> Message:
|
|
316
|
+
"""Convert response content to assistant Message object."""
|
|
317
|
+
content_blocks = []
|
|
318
|
+
for block in content:
|
|
319
|
+
if block.get("type") == "text":
|
|
320
|
+
text_content = block["text"]
|
|
321
|
+
if text_content: # Only add non-empty text blocks
|
|
322
|
+
content_blocks.append(ContentBlock(ContentType.TEXT, text_content))
|
|
323
|
+
elif block.get("type") == "tool_use":
|
|
324
|
+
content_blocks.append(
|
|
325
|
+
ContentBlock(
|
|
326
|
+
ContentType.TOOL_USE,
|
|
327
|
+
{
|
|
328
|
+
"id": block["id"],
|
|
329
|
+
"name": block["name"],
|
|
330
|
+
"input": block["input"],
|
|
331
|
+
},
|
|
332
|
+
)
|
|
333
|
+
)
|
|
334
|
+
return Message(MessageRole.ASSISTANT, content_blocks)
|
|
335
|
+
|
|
336
|
+
async def _execute_and_yield_tool_results(
|
|
265
337
|
self,
|
|
266
|
-
|
|
338
|
+
response_content: list[dict[str, Any]],
|
|
267
339
|
cancellation_token: asyncio.Event | None = None,
|
|
268
|
-
) -> AsyncIterator[StreamEvent]:
|
|
269
|
-
"""
|
|
340
|
+
) -> AsyncIterator[StreamEvent | list[dict[str, Any]]]:
|
|
341
|
+
"""Execute tool calls and yield appropriate stream events."""
|
|
270
342
|
tool_results = []
|
|
271
|
-
for block in response.content:
|
|
272
|
-
# Only check cancellation if token is provided
|
|
273
|
-
if cancellation_token is not None and cancellation_token.is_set():
|
|
274
|
-
return
|
|
275
343
|
|
|
344
|
+
for block in response_content:
|
|
276
345
|
if block.get("type") == "tool_use":
|
|
346
|
+
# Check for cancellation before tool execution
|
|
347
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
348
|
+
yield tool_results
|
|
349
|
+
return
|
|
350
|
+
|
|
277
351
|
yield StreamEvent(
|
|
278
352
|
"tool_use",
|
|
279
353
|
{
|
|
@@ -316,7 +390,53 @@ Guidelines:
|
|
|
316
390
|
|
|
317
391
|
tool_results.append(build_tool_result_block(block["id"], tool_result))
|
|
318
392
|
|
|
319
|
-
yield
|
|
393
|
+
yield tool_results
|
|
394
|
+
|
|
395
|
+
async def _handle_stream_events(
|
|
396
|
+
self,
|
|
397
|
+
stream_iterator: AsyncIterator[Any],
|
|
398
|
+
cancellation_token: asyncio.Event | None = None,
|
|
399
|
+
) -> AsyncIterator[StreamEvent | Any]:
|
|
400
|
+
"""Handle streaming events and yield stream events, return final response."""
|
|
401
|
+
response = None
|
|
402
|
+
|
|
403
|
+
async for event in stream_iterator:
|
|
404
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
405
|
+
yield None
|
|
406
|
+
return
|
|
407
|
+
|
|
408
|
+
# Handle different event types
|
|
409
|
+
if hasattr(event, "type"):
|
|
410
|
+
if event.type == "content_block_start":
|
|
411
|
+
if hasattr(event.content_block, "type"):
|
|
412
|
+
if event.content_block.type == "tool_use":
|
|
413
|
+
yield StreamEvent(
|
|
414
|
+
"tool_use",
|
|
415
|
+
{
|
|
416
|
+
"name": event.content_block.name,
|
|
417
|
+
"status": "started",
|
|
418
|
+
},
|
|
419
|
+
)
|
|
420
|
+
elif event.type == "content_block_delta":
|
|
421
|
+
if hasattr(event.delta, "text"):
|
|
422
|
+
text = event.delta.text
|
|
423
|
+
if text is not None and text: # Only yield non-empty text
|
|
424
|
+
yield StreamEvent("text", text)
|
|
425
|
+
elif isinstance(event, dict) and event.get("type") == "response_ready":
|
|
426
|
+
response = event["data"]
|
|
427
|
+
|
|
428
|
+
yield response
|
|
429
|
+
|
|
430
|
+
def _create_message_request(self, messages: list[Message]) -> CreateMessageRequest:
|
|
431
|
+
"""Create a CreateMessageRequest with standard parameters."""
|
|
432
|
+
return CreateMessageRequest(
|
|
433
|
+
model=self.model,
|
|
434
|
+
messages=messages,
|
|
435
|
+
max_tokens=self.MAX_TOKENS,
|
|
436
|
+
system=self.system_prompt,
|
|
437
|
+
tools=self.tools,
|
|
438
|
+
stream=True,
|
|
439
|
+
)
|
|
320
440
|
|
|
321
441
|
async def query_stream(
|
|
322
442
|
self,
|
|
@@ -329,32 +449,37 @@ Guidelines:
|
|
|
329
449
|
self._last_results = None
|
|
330
450
|
self._last_query = None
|
|
331
451
|
|
|
332
|
-
# Build messages with history if requested
|
|
333
|
-
if use_history:
|
|
334
|
-
messages = self.conversation_history + [
|
|
335
|
-
{"role": "user", "content": user_query}
|
|
336
|
-
]
|
|
337
|
-
else:
|
|
338
|
-
messages = [{"role": "user", "content": user_query}]
|
|
339
|
-
|
|
340
452
|
try:
|
|
341
|
-
#
|
|
453
|
+
# Build messages with history if requested
|
|
454
|
+
messages = []
|
|
455
|
+
if use_history:
|
|
456
|
+
messages = self._convert_history_to_messages()
|
|
457
|
+
|
|
458
|
+
# For OAuth with no history, inject SQL assistant instructions as first user message
|
|
459
|
+
is_oauth = hasattr(self.client, "use_oauth") and self.client.use_oauth
|
|
460
|
+
if is_oauth and not messages:
|
|
461
|
+
instructions = self._get_sql_assistant_instructions()
|
|
462
|
+
messages.append(Message(MessageRole.USER, instructions))
|
|
463
|
+
|
|
464
|
+
# Add current user message
|
|
465
|
+
messages.append(Message(MessageRole.USER, user_query))
|
|
466
|
+
|
|
467
|
+
# Create initial request and get response
|
|
468
|
+
request = self._create_message_request(messages)
|
|
342
469
|
response = None
|
|
343
|
-
|
|
344
|
-
|
|
470
|
+
|
|
471
|
+
async for event in self._handle_stream_events(
|
|
472
|
+
self.client.create_message_with_tools(request, cancellation_token),
|
|
473
|
+
cancellation_token,
|
|
345
474
|
):
|
|
346
|
-
if
|
|
347
|
-
return
|
|
348
|
-
if event.type == "response_ready":
|
|
349
|
-
response = event.data
|
|
350
|
-
else:
|
|
475
|
+
if isinstance(event, StreamEvent):
|
|
351
476
|
yield event
|
|
477
|
+
else:
|
|
478
|
+
response = event
|
|
352
479
|
|
|
480
|
+
# Handle tool use cycles
|
|
353
481
|
collected_content = []
|
|
354
|
-
|
|
355
|
-
# Process tool calls if needed
|
|
356
482
|
while response is not None and response.stop_reason == "tool_use":
|
|
357
|
-
# Check for cancellation at the start of tool cycle
|
|
358
483
|
if cancellation_token is not None and cancellation_token.is_set():
|
|
359
484
|
return
|
|
360
485
|
|
|
@@ -363,82 +488,64 @@ Guidelines:
|
|
|
363
488
|
{"role": "assistant", "content": response.content}
|
|
364
489
|
)
|
|
365
490
|
|
|
366
|
-
#
|
|
367
|
-
# as this would break the tool_use -> tool_result API contract
|
|
491
|
+
# Execute tools and get results
|
|
368
492
|
tool_results = []
|
|
369
|
-
async for event in self.
|
|
370
|
-
response,
|
|
371
|
-
):
|
|
372
|
-
if event
|
|
373
|
-
tool_results = event.data
|
|
374
|
-
else:
|
|
493
|
+
async for event in self._execute_and_yield_tool_results(
|
|
494
|
+
response.content, cancellation_token
|
|
495
|
+
):
|
|
496
|
+
if isinstance(event, StreamEvent):
|
|
375
497
|
yield event
|
|
498
|
+
elif isinstance(event, list):
|
|
499
|
+
tool_results = event
|
|
376
500
|
|
|
377
501
|
# Continue conversation with tool results
|
|
378
502
|
collected_content.append({"role": "user", "content": tool_results})
|
|
379
503
|
if use_history:
|
|
380
504
|
self.conversation_history.extend(collected_content)
|
|
381
505
|
|
|
382
|
-
# Check for cancellation AFTER tool results are complete
|
|
383
506
|
if cancellation_token is not None and cancellation_token.is_set():
|
|
384
507
|
return
|
|
385
508
|
|
|
386
|
-
# Signal that we're processing the tool results
|
|
387
509
|
yield StreamEvent("processing", "Analyzing results...")
|
|
388
510
|
|
|
511
|
+
# Build new messages with collected content
|
|
512
|
+
new_messages = messages.copy()
|
|
513
|
+
for content in collected_content:
|
|
514
|
+
if content["role"] == "user":
|
|
515
|
+
new_messages.append(
|
|
516
|
+
self._convert_tool_results_to_message(content["content"])
|
|
517
|
+
)
|
|
518
|
+
elif content["role"] == "assistant":
|
|
519
|
+
new_messages.append(
|
|
520
|
+
self._convert_response_content_to_message(
|
|
521
|
+
content["content"]
|
|
522
|
+
)
|
|
523
|
+
)
|
|
524
|
+
|
|
389
525
|
# Get next response
|
|
526
|
+
request = self._create_message_request(new_messages)
|
|
390
527
|
response = None
|
|
391
|
-
|
|
392
|
-
|
|
528
|
+
|
|
529
|
+
async for event in self._handle_stream_events(
|
|
530
|
+
self.client.create_message_with_tools(request, cancellation_token),
|
|
531
|
+
cancellation_token,
|
|
393
532
|
):
|
|
394
|
-
if
|
|
395
|
-
return
|
|
396
|
-
if event.type == "response_ready":
|
|
397
|
-
response = event.data
|
|
398
|
-
else:
|
|
533
|
+
if isinstance(event, StreamEvent):
|
|
399
534
|
yield event
|
|
535
|
+
else:
|
|
536
|
+
response = event
|
|
400
537
|
|
|
401
|
-
# Update conversation history
|
|
402
|
-
if use_history:
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
{"role": "assistant", "content": response.content}
|
|
407
|
-
)
|
|
538
|
+
# Update conversation history with final response
|
|
539
|
+
if use_history and response is not None:
|
|
540
|
+
self.conversation_history.append(
|
|
541
|
+
{"role": "assistant", "content": response.content}
|
|
542
|
+
)
|
|
408
543
|
|
|
409
544
|
except asyncio.CancelledError:
|
|
410
545
|
return
|
|
411
546
|
except Exception as e:
|
|
412
547
|
yield StreamEvent("error", str(e))
|
|
413
548
|
|
|
414
|
-
async def
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
"""Create a stream and yield events while building response."""
|
|
418
|
-
stream = await self.client.messages.create(
|
|
419
|
-
model=self.model,
|
|
420
|
-
max_tokens=4096,
|
|
421
|
-
system=self.system_prompt,
|
|
422
|
-
messages=messages,
|
|
423
|
-
tools=self.tools,
|
|
424
|
-
stream=True,
|
|
425
|
-
)
|
|
426
|
-
|
|
427
|
-
content_blocks = []
|
|
428
|
-
tool_use_blocks = []
|
|
429
|
-
|
|
430
|
-
async for event in self._process_stream_events(
|
|
431
|
-
stream, content_blocks, tool_use_blocks, cancellation_token
|
|
432
|
-
):
|
|
433
|
-
# Only check cancellation if token is provided
|
|
434
|
-
if cancellation_token is not None and cancellation_token.is_set():
|
|
435
|
-
return
|
|
436
|
-
yield event
|
|
437
|
-
|
|
438
|
-
# Finalize tool blocks and create response
|
|
439
|
-
stop_reason = self._finalize_tool_blocks(tool_use_blocks)
|
|
440
|
-
content_blocks.extend(tool_use_blocks)
|
|
441
|
-
|
|
442
|
-
yield StreamEvent(
|
|
443
|
-
"response_ready", StreamingResponse(content_blocks, stop_reason)
|
|
444
|
-
)
|
|
549
|
+
async def close(self):
|
|
550
|
+
"""Close the client."""
|
|
551
|
+
await self.client.close()
|