sqlsaber 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/__init__.py +2 -4
- sqlsaber/agents/base.py +18 -221
- sqlsaber/agents/mcp.py +2 -2
- sqlsaber/agents/pydantic_ai_agent.py +170 -0
- sqlsaber/cli/auth.py +146 -79
- sqlsaber/cli/commands.py +22 -7
- sqlsaber/cli/database.py +1 -1
- sqlsaber/cli/interactive.py +65 -30
- sqlsaber/cli/models.py +58 -29
- sqlsaber/cli/streaming.py +114 -77
- sqlsaber/config/api_keys.py +9 -11
- sqlsaber/config/providers.py +116 -0
- sqlsaber/config/settings.py +50 -30
- sqlsaber/database/connection.py +3 -3
- sqlsaber/mcp/mcp.py +43 -51
- sqlsaber/models/__init__.py +0 -3
- sqlsaber/tools/__init__.py +25 -0
- sqlsaber/tools/base.py +85 -0
- sqlsaber/tools/enums.py +21 -0
- sqlsaber/tools/instructions.py +251 -0
- sqlsaber/tools/registry.py +130 -0
- sqlsaber/tools/sql_tools.py +275 -0
- sqlsaber/tools/visualization_tools.py +144 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/METADATA +20 -39
- sqlsaber-0.16.0.dist-info/RECORD +51 -0
- sqlsaber/agents/anthropic.py +0 -579
- sqlsaber/agents/streaming.py +0 -16
- sqlsaber/clients/__init__.py +0 -6
- sqlsaber/clients/anthropic.py +0 -285
- sqlsaber/clients/base.py +0 -31
- sqlsaber/clients/exceptions.py +0 -117
- sqlsaber/clients/models.py +0 -282
- sqlsaber/clients/streaming.py +0 -257
- sqlsaber/models/events.py +0 -28
- sqlsaber-0.14.0.dist-info/RECORD +0 -51
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/anthropic.py
DELETED
|
@@ -1,579 +0,0 @@
|
|
|
1
|
-
"""Anthropic-specific SQL agent implementation using the custom client."""
|
|
2
|
-
|
|
3
|
-
import asyncio
|
|
4
|
-
import json
|
|
5
|
-
from typing import Any, AsyncIterator
|
|
6
|
-
|
|
7
|
-
from sqlsaber.agents.base import BaseSQLAgent
|
|
8
|
-
from sqlsaber.agents.streaming import (
|
|
9
|
-
build_tool_result_block,
|
|
10
|
-
)
|
|
11
|
-
from sqlsaber.clients import AnthropicClient
|
|
12
|
-
from sqlsaber.clients.models import (
|
|
13
|
-
ContentBlock,
|
|
14
|
-
ContentType,
|
|
15
|
-
CreateMessageRequest,
|
|
16
|
-
Message,
|
|
17
|
-
MessageRole,
|
|
18
|
-
ToolDefinition,
|
|
19
|
-
)
|
|
20
|
-
from sqlsaber.config.settings import Config
|
|
21
|
-
from sqlsaber.database.connection import BaseDatabaseConnection
|
|
22
|
-
from sqlsaber.memory.manager import MemoryManager
|
|
23
|
-
from sqlsaber.models.events import StreamEvent
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class AnthropicSQLAgent(BaseSQLAgent):
    """SQL agent that talks to Claude through the custom Anthropic client.

    The agent advertises four tools to the model (``list_tables``,
    ``introspect_schema``, ``execute_sql`` and ``plot_data``), drives the
    model/tool loop itself, and streams progress to the caller as
    ``StreamEvent`` objects.
    """

    # Constants
    MAX_TOKENS = 4096
    DEFAULT_SQL_LIMIT = 100

    def __init__(
        self, db_connection: BaseDatabaseConnection, database_name: str | None = None
    ):
        """Create the agent.

        Args:
            db_connection: Connection handed to ``BaseSQLAgent``.
            database_name: Optional name used to look up stored memories. When
                ``None`` no memory context is added and ``add_memory`` is a
                no-op.

        Raises:
            ValueError: From ``Config.validate()`` when credentials are missing.
        """
        super().__init__(db_connection)

        config = Config()
        config.validate()  # This will raise ValueError if credentials are missing

        # Prefer OAuth when a token is configured; otherwise use the API key.
        if config.oauth_token:
            self.client = AnthropicClient(oauth_token=config.oauth_token)
        else:
            self.client = AnthropicClient(api_key=config.api_key)
        # Strip the "anthropic:" provider prefix the configured name may carry.
        self.model = config.model_name.replace("anthropic:", "")

        self.database_name = database_name
        self.memory_manager = MemoryManager()

        # Rows/query of the most recent successful execute_sql call; surfaced
        # to the caller as a "query_result" stream event.
        self._last_results = None
        self._last_query = None

        # Tool schemas sent with every request.
        self.tools: list[ToolDefinition] = [
            ToolDefinition(
                name="list_tables",
                description="Get a list of all tables in the database with row counts. Use this first to discover available tables.",
                input_schema={
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            ),
            ToolDefinition(
                name="introspect_schema",
                description="Introspect database schema to understand table structures.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "table_pattern": {
                            "type": "string",
                            "description": "Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')",
                        }
                    },
                    "required": [],
                },
            ),
            ToolDefinition(
                name="execute_sql",
                description="Execute a SQL query against the database.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "SQL query to execute",
                        },
                        "limit": {
                            "type": "integer",
                            "description": f"Maximum number of rows to return (default: {AnthropicSQLAgent.DEFAULT_SQL_LIMIT})",
                            "default": AnthropicSQLAgent.DEFAULT_SQL_LIMIT,
                        },
                    },
                    "required": ["query"],
                },
            ),
            ToolDefinition(
                name="plot_data",
                description="Create a plot of query results.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "y_values": {
                            "type": "array",
                            "items": {"type": ["number", "null"]},
                            "description": "Y-axis data points (required)",
                        },
                        "x_values": {
                            "type": "array",
                            "items": {"type": ["number", "null"]},
                            "description": "X-axis data points (optional, will use indices if not provided)",
                        },
                        "plot_type": {
                            "type": "string",
                            "enum": ["line", "scatter", "histogram"],
                            "description": "Type of plot to create (default: line)",
                            "default": "line",
                        },
                        "title": {
                            "type": "string",
                            "description": "Title for the plot",
                        },
                        "x_label": {
                            "type": "string",
                            "description": "Label for X-axis",
                        },
                        "y_label": {
                            "type": "string",
                            "description": "Label for Y-axis",
                        },
                    },
                    "required": ["y_values"],
                },
            ),
        ]

        # Build system prompt with memories if available
        self.system_prompt = self._build_system_prompt()

    def _is_oauth(self) -> bool:
        """Return True when the client is configured to authenticate via OAuth."""
        # getattr guards allow calls before self.client exists.
        client = getattr(self, "client", None)
        return bool(getattr(client, "use_oauth", False))

    def _build_system_prompt(self) -> str:
        """Build the system prompt.

        OAuth sessions get only the Claude Code identity line. Their SQL
        instructions are sent as a leading user message by ``query_stream``.
        API-key sessions get the full SQL assistant instructions.
        """
        if self._is_oauth():
            return "You are Claude Code, Anthropic's official CLI for Claude."
        return self._get_sql_assistant_instructions()

    def _get_sql_assistant_instructions(self) -> str:
        """Return the SQL assistant instructions plus any stored memories."""
        db_type = self._get_database_type_name()
        instructions = f"""You are also a helpful SQL assistant that helps users query their {db_type} database.

Your responsibilities:
1. Understand user's natural language requests, think and convert them to SQL
2. Use the provided tools efficiently to explore database schema
3. Generate appropriate SQL queries
4. Execute queries safely - queries that modify the database are not allowed
5. Format and explain results clearly
6. Create visualizations when requested or when they would be helpful

IMPORTANT - Schema Discovery Strategy:
1. ALWAYS start with 'list_tables' to see available tables and row counts
2. Based on the user's query, identify which specific tables are relevant
3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables
4. Timestamp columns must be converted to text when you write queries

Guidelines:
- Use list_tables first, then introspect_schema for specific tables only
- Use table patterns like 'sample%' or '%experiment%' to filter related tables
- Use proper JOIN syntax and avoid cartesian products
- Include appropriate WHERE clauses to limit results
- Explain what the query does in simple terms
- Handle errors gracefully and suggest fixes
- Be security conscious - use parameterized queries when needed
"""

        # Add memory context if database name is available
        if self.database_name:
            memory_context = self.memory_manager.format_memories_for_prompt(
                self.database_name
            )
            if memory_context.strip():
                instructions += memory_context

        return instructions

    def add_memory(self, content: str) -> str | None:
        """Store a memory for the current database and refresh the system prompt.

        Returns:
            The new memory's id, or ``None`` when no database name is set.
        """
        if not self.database_name:
            return None

        memory = self.memory_manager.add_memory(self.database_name, content)
        # Rebuild system prompt with new memory
        self.system_prompt = self._build_system_prompt()
        return memory.id

    async def execute_sql(self, query: str, limit: int | None = None) -> str:
        """Execute *query* via the base implementation and remember its rows.

        On success the rows (truncated to *limit* when given) and the query
        text are kept for the streaming loop. The result string from the base
        implementation is returned unchanged.
        """
        # Clear first so a failing query never re-surfaces rows from an
        # earlier call in the same turn.
        self._last_results = None
        self._last_query = None

        result = await super().execute_sql(query, limit)

        try:
            result_data = json.loads(result)
        except (json.JSONDecodeError, TypeError):
            # Unparseable result: nothing to stream; return it as-is.
            return result

        if (
            isinstance(result_data, dict)
            and result_data.get("success")
            and "results" in result_data
        ):
            rows = result_data["results"]
            self._last_results = rows if limit is None else rows[:limit]
            self._last_query = query

        return result

    async def process_tool_call(
        self, tool_name: str, tool_input: dict[str, Any]
    ) -> str:
        """Process a tool call and return the result (delegates to the base class)."""
        return await super().process_tool_call(tool_name, tool_input)

    @staticmethod
    def _assistant_content_blocks(blocks: list[Any]) -> list[ContentBlock]:
        """Convert assistant ``text``/``tool_use`` dicts to ContentBlocks.

        Empty text blocks, non-dict items and other block types are skipped.
        """
        converted: list[ContentBlock] = []
        for block in blocks:
            if not isinstance(block, dict):
                continue
            block_type = block.get("type")
            if block_type == "text":
                text = block.get("text", "")
                if text:  # Only add non-empty text blocks
                    converted.append(ContentBlock(ContentType.TEXT, text))
            elif block_type == "tool_use":
                converted.append(
                    ContentBlock(
                        ContentType.TOOL_USE,
                        {
                            "id": block["id"],
                            "name": block["name"],
                            "input": block["input"],
                        },
                    )
                )
        return converted

    def _convert_user_message_to_message(
        self, msg_content: str | list[dict[str, Any]]
    ) -> Message:
        """Convert a stored user turn (plain text or tool results) to a Message."""
        if isinstance(msg_content, str):
            return Message(MessageRole.USER, msg_content)

        if isinstance(msg_content, list):
            tool_result_blocks = [
                ContentBlock(ContentType.TOOL_RESULT, item)
                for item in msg_content
                if isinstance(item, dict) and item.get("type") == "tool_result"
            ]
            if tool_result_blocks:
                return Message(MessageRole.USER, tool_result_blocks)

        # Fallback to string representation
        return Message(MessageRole.USER, str(msg_content))

    def _convert_assistant_message_to_message(
        self, msg_content: str | list[dict[str, Any]]
    ) -> Message:
        """Convert a stored assistant turn to a Message."""
        if isinstance(msg_content, str):
            return Message(MessageRole.ASSISTANT, msg_content)

        if isinstance(msg_content, list):
            content_blocks = self._assistant_content_blocks(msg_content)
            if content_blocks:
                return Message(MessageRole.ASSISTANT, content_blocks)

        # Fallback to string representation
        return Message(MessageRole.ASSISTANT, str(msg_content))

    def _convert_history_to_messages(self) -> list[Message]:
        """Convert ``conversation_history`` entries to Message objects.

        Entries whose role is neither "user" nor "assistant" are dropped.
        """
        messages = []
        for msg in self.conversation_history:
            if msg["role"] == "user":
                messages.append(self._convert_user_message_to_message(msg["content"]))
            elif msg["role"] == "assistant":
                messages.append(
                    self._convert_assistant_message_to_message(msg["content"])
                )
        return messages

    def _convert_tool_results_to_message(
        self, tool_results: list[dict[str, Any]]
    ) -> Message:
        """Wrap tool_result blocks in a user Message."""
        return Message(
            MessageRole.USER,
            [ContentBlock(ContentType.TOOL_RESULT, result) for result in tool_results],
        )

    def _convert_response_content_to_message(
        self, content: list[dict[str, Any]]
    ) -> Message:
        """Convert a model response's content blocks to an assistant Message."""
        return Message(MessageRole.ASSISTANT, self._assistant_content_blocks(content))

    async def _execute_and_yield_tool_results(
        self,
        response_content: list[dict[str, Any]],
        cancellation_token: asyncio.Event | None = None,
    ) -> AsyncIterator[StreamEvent | list[dict[str, Any]]]:
        """Run every ``tool_use`` block in *response_content*.

        Yields ``StreamEvent`` objects while tools run and, as the last item,
        the list of ``tool_result`` blocks. If cancellation is requested
        before a tool starts, the results gathered so far are yielded and the
        generator stops. The list can then be shorter than the number of
        ``tool_use`` blocks.
        """
        tool_results: list[dict[str, Any]] = []

        for block in response_content:
            if block.get("type") != "tool_use":
                continue

            # Check for cancellation before tool execution
            if cancellation_token is not None and cancellation_token.is_set():
                yield tool_results
                return

            name = block["name"]
            tool_input = block["input"]
            yield StreamEvent(
                "tool_use",
                {"name": name, "input": tool_input, "status": "executing"},
            )

            tool_result = await self.process_tool_call(name, tool_input)

            # Yield specific events based on tool type
            if name == "execute_sql" and self._last_results:
                yield StreamEvent(
                    "query_result",
                    {"query": self._last_query, "results": self._last_results},
                )
            elif name in ("list_tables", "introspect_schema"):
                yield StreamEvent(
                    "tool_result", {"tool_name": name, "result": tool_result}
                )
            elif name == "plot_data":
                yield StreamEvent(
                    "plot_result",
                    {"tool_name": name, "input": tool_input, "result": tool_result},
                )

            tool_results.append(build_tool_result_block(block["id"], tool_result))

        yield tool_results

    async def _handle_stream_events(
        self,
        stream_iterator: AsyncIterator[Any],
        cancellation_token: asyncio.Event | None = None,
    ) -> AsyncIterator[StreamEvent | Any]:
        """Translate raw client stream events into ``StreamEvent`` objects.

        Yields events for tool starts and non-empty text deltas. The final
        item is the response carried by the ``response_ready`` event, or
        ``None`` if none arrived. On cancellation it yields ``None`` and stops.
        """
        response = None

        async for event in stream_iterator:
            if cancellation_token is not None and cancellation_token.is_set():
                yield None
                return

            if hasattr(event, "type"):
                if event.type == "content_block_start":
                    block = event.content_block
                    if getattr(block, "type", None) == "tool_use":
                        yield StreamEvent(
                            "tool_use", {"name": block.name, "status": "started"}
                        )
                elif event.type == "content_block_delta":
                    text = getattr(event.delta, "text", None)
                    if text:  # Only yield non-empty text
                        yield StreamEvent("text", text)
            elif isinstance(event, dict) and event.get("type") == "response_ready":
                response = event["data"]

        yield response

    def _create_message_request(self, messages: list[Message]) -> CreateMessageRequest:
        """Create a streaming CreateMessageRequest with the agent's tools and prompt."""
        return CreateMessageRequest(
            model=self.model,
            messages=messages,
            max_tokens=self.MAX_TOKENS,
            system=self.system_prompt,
            tools=self.tools,
            stream=True,
        )

    def _build_request_messages(
        self, user_query: str, use_history: bool
    ) -> list[Message]:
        """Assemble the messages for the first request of a turn.

        With history enabled, the current query is already the last history
        entry. Otherwise it is sent on its own. For OAuth sessions the SQL
        assistant instructions are prepended as a leading user message on
        every turn, because the OAuth system prompt carries only the Claude
        Code identity.
        """
        messages = self._convert_history_to_messages() if use_history else []
        if not use_history:
            messages.append(Message(MessageRole.USER, user_query))
        if self._is_oauth():
            messages.insert(
                0, Message(MessageRole.USER, self._get_sql_assistant_instructions())
            )
        return messages

    async def query_stream(
        self,
        user_query: str,
        use_history: bool = True,
        cancellation_token: asyncio.Event | None = None,
    ) -> AsyncIterator[StreamEvent]:
        """Answer *user_query*, streaming progress as ``StreamEvent`` objects.

        While the model stops with ``tool_use``, the requested tools are run
        and their results sent back. Text deltas, tool activity, query/plot
        results and a "processing" marker between rounds are yielded. Any
        unexpected exception is reported as an ``"error"`` event rather than
        raised.

        Args:
            user_query: The user's natural-language request.
            use_history: When True, prior turns are sent and this turn is
                recorded in ``conversation_history`` and persisted.
            cancellation_token: When set, streaming stops at the next check.
        """
        # Initialize for tracking state
        self._last_results = None
        self._last_query = None

        try:
            # Ensure conversation is active for persistence
            await self._ensure_conversation()

            # Store user message in conversation history and persistence
            if use_history:
                self.conversation_history.append(
                    {"role": "user", "content": user_query}
                )
                await self._store_user_message(user_query)

            request_messages = self._build_request_messages(user_query, use_history)

            while True:
                response = None
                request = self._create_message_request(request_messages)
                async for event in self._handle_stream_events(
                    self.client.create_message_with_tools(request, cancellation_token),
                    cancellation_token,
                ):
                    if isinstance(event, StreamEvent):
                        yield event
                    else:
                        response = event

                if response is None or response.stop_reason != "tool_use":
                    break

                # --- Tool-use round ---
                if cancellation_token is not None and cancellation_token.is_set():
                    return

                if use_history:
                    await self._store_assistant_message(response.content)

                tool_results: list[dict[str, Any]] = []
                async for event in self._execute_and_yield_tool_results(
                    response.content, cancellation_token
                ):
                    if isinstance(event, StreamEvent):
                        yield event
                    elif isinstance(event, list):
                        tool_results = event

                expected = sum(
                    1 for block in response.content if block.get("type") == "tool_use"
                )
                if len(tool_results) < expected:
                    # Cancelled part-way through. A tool_use without its
                    # tool_result would make the history invalid for the next
                    # request, so this round is not recorded.
                    return

                if use_history:
                    self.conversation_history.extend(
                        [
                            {"role": "assistant", "content": response.content},
                            {"role": "user", "content": tool_results},
                        ]
                    )
                    await self._store_tool_message(tool_results)

                if cancellation_token is not None and cancellation_token.is_set():
                    return

                yield StreamEvent("processing", "Analyzing results...")

                # Extend the outgoing conversation with this round and loop.
                request_messages.append(
                    self._convert_response_content_to_message(response.content)
                )
                request_messages.append(
                    self._convert_tool_results_to_message(tool_results)
                )

            # The loop only exits with no response or a non-tool_use response.
            if use_history and response is not None:
                self.conversation_history.append(
                    {"role": "assistant", "content": response.content}
                )
                await self._store_assistant_message(response.content)

        except asyncio.CancelledError:
            return
        except Exception as e:
            yield StreamEvent("error", str(e))

    async def close(self):
        """Close the client."""
        await self.client.close()
|
sqlsaber/agents/streaming.py
DELETED
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
"""Streaming utilities for agents."""
|
|
2
|
-
|
|
3
|
-
from typing import Any
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class StreamingResponse:
    """Plain holder pairing a streamed response's content with its stop reason.

    Both values are stored exactly as passed in; no copying or validation
    takes place.
    """

    def __init__(self, content: list[dict[str, Any]], stop_reason: str):
        # Content blocks first, then the reason generation stopped.
        self.content, self.stop_reason = content, stop_reason
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def build_tool_result_block(tool_use_id: str, content: str) -> dict[str, Any]:
    """Return a ``tool_result`` block answering the tool call *tool_use_id*.

    The block has the keys ``type``, ``tool_use_id`` and ``content`` (in that
    order), with *content* passed through unchanged.
    """
    block = dict(type="tool_result", tool_use_id=tool_use_id, content=content)
    return block
|