sqlsaber 0.7.0.tar.gz → 0.8.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. See the package registry for more details.

Files changed (67)
  1. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/CHANGELOG.md +21 -0
  2. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/PKG-INFO +1 -1
  3. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/pyproject.toml +1 -1
  4. sqlsaber-0.8.0/src/sqlsaber/agents/anthropic.py +551 -0
  5. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/agents/base.py +11 -11
  6. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/agents/streaming.py +3 -3
  7. sqlsaber-0.8.0/src/sqlsaber/cli/auth.py +142 -0
  8. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/cli/commands.py +9 -4
  9. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/cli/completers.py +3 -5
  10. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/cli/database.py +9 -10
  11. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/cli/display.py +5 -7
  12. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/cli/interactive.py +2 -3
  13. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/cli/memory.py +7 -9
  14. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/cli/models.py +1 -2
  15. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/cli/streaming.py +5 -31
  16. sqlsaber-0.8.0/src/sqlsaber/clients/__init__.py +6 -0
  17. sqlsaber-0.8.0/src/sqlsaber/clients/anthropic.py +285 -0
  18. sqlsaber-0.8.0/src/sqlsaber/clients/base.py +31 -0
  19. sqlsaber-0.8.0/src/sqlsaber/clients/exceptions.py +117 -0
  20. sqlsaber-0.8.0/src/sqlsaber/clients/models.py +282 -0
  21. sqlsaber-0.8.0/src/sqlsaber/clients/streaming.py +257 -0
  22. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/config/api_keys.py +2 -3
  23. sqlsaber-0.8.0/src/sqlsaber/config/auth.py +86 -0
  24. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/config/database.py +20 -20
  25. sqlsaber-0.8.0/src/sqlsaber/config/oauth_flow.py +274 -0
  26. sqlsaber-0.8.0/src/sqlsaber/config/oauth_tokens.py +175 -0
  27. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/config/settings.py +34 -23
  28. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/database/connection.py +9 -9
  29. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/database/schema.py +25 -25
  30. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/mcp/mcp.py +3 -4
  31. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/memory/manager.py +3 -5
  32. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/memory/storage.py +7 -8
  33. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/models/events.py +4 -4
  34. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/models/types.py +10 -10
  35. sqlsaber-0.8.0/tests/test_agents/test_anthropic_oauth.py +78 -0
  36. sqlsaber-0.8.0/tests/test_clients/test_anthropic_client.py +91 -0
  37. sqlsaber-0.8.0/tests/test_clients/test_streaming.py +282 -0
  38. sqlsaber-0.8.0/tests/test_config/test_oauth.py +179 -0
  39. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/tests/test_config/test_settings.py +0 -37
  40. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/uv.lock +1 -1
  41. sqlsaber-0.7.0/src/sqlsaber/agents/anthropic.py +0 -444
  42. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/.github/workflows/publish.yml +0 -0
  43. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/.gitignore +0 -0
  44. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/.python-version +0 -0
  45. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/CLAUDE.md +0 -0
  46. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/LICENSE +0 -0
  47. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/README.md +0 -0
  48. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/pytest.ini +0 -0
  49. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/sqlsaber.svg +0 -0
  50. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/__init__.py +0 -0
  51. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/__main__.py +0 -0
  52. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/agents/__init__.py +0 -0
  53. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/agents/mcp.py +0 -0
  54. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/cli/__init__.py +0 -0
  55. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/config/__init__.py +0 -0
  56. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/database/__init__.py +0 -0
  57. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/mcp/__init__.py +0 -0
  58. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/memory/__init__.py +0 -0
  59. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/src/sqlsaber/models/__init__.py +0 -0
  60. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/tests/__init__.py +0 -0
  61. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/tests/conftest.py +0 -0
  62. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/tests/test_cli/__init__.py +0 -0
  63. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/tests/test_cli/test_commands.py +0 -0
  64. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/tests/test_config/__init__.py +0 -0
  65. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/tests/test_config/test_database.py +0 -0
  66. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/tests/test_database/__init__.py +0 -0
  67. {sqlsaber-0.7.0 → sqlsaber-0.8.0}/tests/test_database/test_connection.py +0 -0
@@ -4,6 +4,27 @@ All notable changes to SQLSaber will be documented in this file.
4
4
 
5
5
  ## [Unreleased]
6
6
 
7
+ ## [0.8.0] - 2025-07-07
8
+
9
+ ### Added
10
+
11
+ - OAuth support for Claude Pro/Max subscriptions
12
+ - Authentication management with `saber auth` command
13
+ - Interactive setup for API key or Claude Pro/Max subscription
14
+ - `saber auth setup`
15
+ - `saber auth status`
16
+ - `saber auth reset`
17
+ - Persistent storage of user authentication preferences
18
+ - New `clients` module with custom Anthropic API client
19
+ - `AnthropicClient` for direct API communication
20
+
21
+ ### Changed
22
+
23
+ - Enhanced authentication system to support both API keys and OAuth tokens
24
+ - Replaced Anthropic SDK with direct API implementation using httpx
25
+ - Modernized type annotations throughout the codebase
26
+ - Refactored query streaming into smaller, more maintainable functions
27
+
7
28
  ## [0.7.0] - 2025-07-01
8
29
 
9
30
  ### Added
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sqlsaber
3
- Version: 0.7.0
3
+ Version: 0.8.0
4
4
  Summary: SQLSaber - Agentic SQL assistant like Claude Code
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "sqlsaber"
3
- version = "0.7.0"
3
+ version = "0.8.0"
4
4
  description = "SQLSaber - Agentic SQL assistant like Claude Code"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -0,0 +1,551 @@
1
+ """Anthropic-specific SQL agent implementation using the custom client."""
2
+
3
+ import asyncio
4
+ import json
5
+ from typing import Any, AsyncIterator
6
+
7
+ from sqlsaber.agents.base import BaseSQLAgent
8
+ from sqlsaber.agents.streaming import (
9
+ build_tool_result_block,
10
+ )
11
+ from sqlsaber.clients import AnthropicClient
12
+ from sqlsaber.clients.models import (
13
+ ContentBlock,
14
+ ContentType,
15
+ CreateMessageRequest,
16
+ Message,
17
+ MessageRole,
18
+ ToolDefinition,
19
+ )
20
+ from sqlsaber.config.settings import Config
21
+ from sqlsaber.database.connection import BaseDatabaseConnection
22
+ from sqlsaber.memory.manager import MemoryManager
23
+ from sqlsaber.models.events import StreamEvent
24
+
25
+
26
class AnthropicSQLAgent(BaseSQLAgent):
    """SQL agent backed by the custom ``AnthropicClient``.

    The client authenticates with an OAuth token when ``Config`` provides one
    and with an API key otherwise. The model is given four tools:
    ``list_tables``, ``introspect_schema``, ``execute_sql`` and ``plot_data``.
    """

    # Maximum number of tokens requested per model response.
    MAX_TOKENS = 4096
    # Row limit advertised as the default in the ``execute_sql`` tool schema.
    DEFAULT_SQL_LIMIT = 100

    def __init__(
        self, db_connection: BaseDatabaseConnection, database_name: str | None = None
    ):
        """Configure the Anthropic client, tool definitions and system prompt.

        Args:
            db_connection: Database connection handed to ``BaseSQLAgent``.
            database_name: Key for stored memories; when ``None`` no memories
                are included in the prompt and ``add_memory`` returns ``None``.

        Raises:
            ValueError: From ``Config.validate()`` when credentials are missing.
        """
        super().__init__(db_connection)

        config = Config()
        config.validate()  # This will raise ValueError if credentials are missing

        # An OAuth token (Claude Pro/Max subscription) takes precedence over an
        # API key.
        if config.oauth_token:
            self.client = AnthropicClient(oauth_token=config.oauth_token)
        else:
            self.client = AnthropicClient(api_key=config.api_key)
        # Strip only a leading provider tag; str.replace would also remove an
        # "anthropic:" appearing elsewhere in the model name.
        self.model = config.model_name.removeprefix("anthropic:")

        self.database_name = database_name
        self.memory_manager = MemoryManager()

        # Rows and SQL text of the latest successful execute_sql call in the
        # current query; streamed to the caller as a "query_result" event.
        self._last_results: list[Any] | None = None
        self._last_query: str | None = None

        # Tool definitions sent with every request.
        self.tools: list[ToolDefinition] = [
            ToolDefinition(
                name="list_tables",
                description="Get a list of all tables in the database with row counts. Use this first to discover available tables.",
                input_schema={
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            ),
            ToolDefinition(
                name="introspect_schema",
                description="Introspect database schema to understand table structures.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "table_pattern": {
                            "type": "string",
                            "description": "Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')",
                        }
                    },
                    "required": [],
                },
            ),
            ToolDefinition(
                name="execute_sql",
                description="Execute a SQL query against the database.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "SQL query to execute",
                        },
                        "limit": {
                            "type": "integer",
                            "description": f"Maximum number of rows to return (default: {self.DEFAULT_SQL_LIMIT})",
                            "default": self.DEFAULT_SQL_LIMIT,
                        },
                    },
                    "required": ["query"],
                },
            ),
            ToolDefinition(
                name="plot_data",
                description="Create a plot of query results.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "y_values": {
                            "type": "array",
                            "items": {"type": ["number", "null"]},
                            "description": "Y-axis data points (required)",
                        },
                        "x_values": {
                            "type": "array",
                            "items": {"type": ["number", "null"]},
                            "description": "X-axis data points (optional, will use indices if not provided)",
                        },
                        "plot_type": {
                            "type": "string",
                            "enum": ["line", "scatter", "histogram"],
                            "description": "Type of plot to create (default: line)",
                            "default": "line",
                        },
                        "title": {
                            "type": "string",
                            "description": "Title for the plot",
                        },
                        "x_label": {
                            "type": "string",
                            "description": "Label for X-axis",
                        },
                        "y_label": {
                            "type": "string",
                            "description": "Label for Y-axis",
                        },
                    },
                    "required": ["y_values"],
                },
            ),
        ]

        # Build system prompt with memories if available.
        self.system_prompt = self._build_system_prompt()

    @staticmethod
    def _is_cancelled(cancellation_token: asyncio.Event | None) -> bool:
        """Return True when a cancellation token was supplied and has been set."""
        return cancellation_token is not None and cancellation_token.is_set()

    def _uses_oauth(self) -> bool:
        """Return whether the client authenticates with an OAuth token."""
        # getattr keeps this safe if called before ``self.client`` is assigned.
        client = getattr(self, "client", None)
        return bool(getattr(client, "use_oauth", False))

    def _build_system_prompt(self) -> str:
        """Return the system prompt for the current authentication mode.

        With OAuth the system prompt is only the Claude Code identity, and
        ``query_stream`` sends the SQL instructions as a leading user message.
        Otherwise the full SQL assistant instructions (with memories) are used.
        """
        if self._uses_oauth():
            return "You are Claude Code, Anthropic's official CLI for Claude."
        return self._get_sql_assistant_instructions()

    def _get_sql_assistant_instructions(self) -> str:
        """Return the SQL assistant instructions, plus memories when available.

        Memories are appended only when ``database_name`` is set and the
        formatted memory text is not blank.
        """
        db_type = self._get_database_type_name()
        instructions = f"""You are also a helpful SQL assistant that helps users query their {db_type} database.

Your responsibilities:
1. Understand user's natural language requests, think and convert them to SQL
2. Use the provided tools efficiently to explore database schema
3. Generate appropriate SQL queries
4. Execute queries safely - queries that modify the database are not allowed
5. Format and explain results clearly
6. Create visualizations when requested or when they would be helpful

IMPORTANT - Schema Discovery Strategy:
1. ALWAYS start with 'list_tables' to see available tables and row counts
2. Based on the user's query, identify which specific tables are relevant
3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables
4. Timestamp columns must be converted to text when you write queries

Guidelines:
- Use list_tables first, then introspect_schema for specific tables only
- Use table patterns like 'sample%' or '%experiment%' to filter related tables
- Use proper JOIN syntax and avoid cartesian products
- Include appropriate WHERE clauses to limit results
- Explain what the query does in simple terms
- Handle errors gracefully and suggest fixes
- Be security conscious - use parameterized queries when needed
"""

        if self.database_name:
            memory_context = self.memory_manager.format_memories_for_prompt(
                self.database_name
            )
            if memory_context.strip():
                instructions += memory_context

        return instructions

    def add_memory(self, content: str) -> str | None:
        """Store a memory for the current database and refresh the system prompt.

        Returns:
            The new memory's id, or ``None`` when no database name is set.
        """
        if not self.database_name:
            return None

        memory = self.memory_manager.add_memory(self.database_name, content)
        # Rebuild so the new memory is part of subsequent requests.
        self.system_prompt = self._build_system_prompt()
        return memory.id

    async def execute_sql(self, query: str, limit: int | None = None) -> str:
        """Run ``query`` via the base implementation and capture rows for streaming.

        The JSON result string from ``BaseSQLAgent.execute_sql`` is returned
        unchanged. On success, the rows (truncated to ``limit`` when given) and
        the query text are stored in ``_last_results`` / ``_last_query``.
        """
        result = await super().execute_sql(query, limit)

        # Clear the previous capture first, so a failed query does not cause
        # the preceding query's rows to be streamed again.
        self._last_results = None
        self._last_query = None

        try:
            result_data = json.loads(result)
        except json.JSONDecodeError:
            # Unparseable result: return it without capturing anything.
            return result

        if (
            isinstance(result_data, dict)
            and result_data.get("success")
            and "results" in result_data
        ):
            # A slice with limit=None copies every row.
            self._last_results = result_data["results"][:limit]
            self._last_query = query

        return result

    async def process_tool_call(
        self, tool_name: str, tool_input: dict[str, Any]
    ) -> str:
        """Process a tool call and return the result (delegates to the base class)."""
        return await super().process_tool_call(tool_name, tool_input)

    @staticmethod
    def _to_assistant_blocks(blocks: list[Any]) -> list[ContentBlock]:
        """Convert raw assistant content dicts into ``ContentBlock`` objects.

        Only ``text`` and ``tool_use`` blocks are kept. Empty text blocks and
        non-dict items are dropped.
        """
        content_blocks: list[ContentBlock] = []
        for block in blocks:
            if not isinstance(block, dict):
                continue
            block_type = block.get("type")
            if block_type == "text":
                text_content = block.get("text", "")
                if text_content:  # Only add non-empty text blocks
                    content_blocks.append(ContentBlock(ContentType.TEXT, text_content))
            elif block_type == "tool_use":
                content_blocks.append(
                    ContentBlock(
                        ContentType.TOOL_USE,
                        {
                            "id": block["id"],
                            "name": block["name"],
                            "input": block["input"],
                        },
                    )
                )
        return content_blocks

    def _convert_user_message_to_message(
        self, msg_content: str | list[dict[str, Any]]
    ) -> Message:
        """Convert stored user content (text or tool results) into a ``Message``."""
        if isinstance(msg_content, str):
            return Message(MessageRole.USER, msg_content)

        if isinstance(msg_content, list):
            tool_result_blocks = [
                ContentBlock(ContentType.TOOL_RESULT, item)
                for item in msg_content
                if isinstance(item, dict) and item.get("type") == "tool_result"
            ]
            if tool_result_blocks:
                return Message(MessageRole.USER, tool_result_blocks)

        # Fallback to string representation.
        return Message(MessageRole.USER, str(msg_content))

    def _convert_assistant_message_to_message(
        self, msg_content: str | list[dict[str, Any]]
    ) -> Message:
        """Convert stored assistant content into a ``Message``."""
        if isinstance(msg_content, str):
            return Message(MessageRole.ASSISTANT, msg_content)

        if isinstance(msg_content, list):
            content_blocks = self._to_assistant_blocks(msg_content)
            if content_blocks:
                return Message(MessageRole.ASSISTANT, content_blocks)

        # Fallback to string representation.
        return Message(MessageRole.ASSISTANT, str(msg_content))

    def _convert_history_to_messages(self) -> list[Message]:
        """Convert ``conversation_history`` dicts into ``Message`` objects.

        Entries whose role is neither ``user`` nor ``assistant`` are skipped.
        """
        messages: list[Message] = []
        for msg in self.conversation_history:
            if msg["role"] == "user":
                messages.append(self._convert_user_message_to_message(msg["content"]))
            elif msg["role"] == "assistant":
                messages.append(
                    self._convert_assistant_message_to_message(msg["content"])
                )
        return messages

    def _convert_tool_results_to_message(
        self, tool_results: list[dict[str, Any]]
    ) -> Message:
        """Wrap tool result dicts in a user ``Message``."""
        return Message(
            MessageRole.USER,
            [ContentBlock(ContentType.TOOL_RESULT, result) for result in tool_results],
        )

    def _convert_response_content_to_message(
        self, content: list[dict[str, Any]]
    ) -> Message:
        """Convert a model response's content list into an assistant ``Message``."""
        return Message(MessageRole.ASSISTANT, self._to_assistant_blocks(content))

    async def _execute_and_yield_tool_results(
        self,
        response_content: list[dict[str, Any]],
        cancellation_token: asyncio.Event | None = None,
    ) -> AsyncIterator[StreamEvent | list[dict[str, Any]]]:
        """Run every ``tool_use`` block and stream progress events.

        Yields ``StreamEvent`` objects while tools run. The last item yielded
        is the list of tool result blocks. That list is partial if the
        cancellation token was set before every tool ran.
        """
        tool_results: list[dict[str, Any]] = []

        for block in response_content:
            if block.get("type") != "tool_use":
                continue

            # Check for cancellation before each tool execution.
            if self._is_cancelled(cancellation_token):
                yield tool_results
                return

            tool_name = block["name"]
            yield StreamEvent(
                "tool_use",
                {
                    "name": tool_name,
                    "input": block["input"],
                    "status": "executing",
                },
            )

            tool_result = await self.process_tool_call(tool_name, block["input"])

            # Emit a tool-specific event for the UI.
            if tool_name == "execute_sql" and self._last_results:
                yield StreamEvent(
                    "query_result",
                    {
                        "query": self._last_query,
                        "results": self._last_results,
                    },
                )
            elif tool_name in ("list_tables", "introspect_schema"):
                yield StreamEvent(
                    "tool_result",
                    {
                        "tool_name": tool_name,
                        "result": tool_result,
                    },
                )
            elif tool_name == "plot_data":
                yield StreamEvent(
                    "plot_result",
                    {
                        "tool_name": tool_name,
                        "input": block["input"],
                        "result": tool_result,
                    },
                )

            tool_results.append(build_tool_result_block(block["id"], tool_result))

        yield tool_results

    async def _handle_stream_events(
        self,
        stream_iterator: AsyncIterator[Any],
        cancellation_token: asyncio.Event | None = None,
    ) -> AsyncIterator[StreamEvent | Any]:
        """Translate raw client stream events into ``StreamEvent`` objects.

        Yields ``StreamEvent`` items for tool-use starts and non-empty text
        deltas. The last item yielded is the final response: the ``data`` of
        the ``response_ready`` event, or ``None`` if no such event arrived or
        the token was set.
        """
        response = None

        async for event in stream_iterator:
            if self._is_cancelled(cancellation_token):
                yield None
                return

            if hasattr(event, "type"):
                if event.type == "content_block_start":
                    if (
                        hasattr(event.content_block, "type")
                        and event.content_block.type == "tool_use"
                    ):
                        yield StreamEvent(
                            "tool_use",
                            {
                                "name": event.content_block.name,
                                "status": "started",
                            },
                        )
                elif event.type == "content_block_delta":
                    if hasattr(event.delta, "text"):
                        text = event.delta.text
                        if text:  # Only yield non-empty text
                            yield StreamEvent("text", text)
            elif isinstance(event, dict) and event.get("type") == "response_ready":
                response = event["data"]

        yield response

    def _create_message_request(self, messages: list[Message]) -> CreateMessageRequest:
        """Build a streaming ``CreateMessageRequest`` with the agent's settings."""
        return CreateMessageRequest(
            model=self.model,
            messages=messages,
            max_tokens=self.MAX_TOKENS,
            system=self.system_prompt,
            tools=self.tools,
            stream=True,
        )

    async def query_stream(
        self,
        user_query: str,
        use_history: bool = True,
        cancellation_token: asyncio.Event | None = None,
    ) -> AsyncIterator[StreamEvent]:
        """Answer ``user_query``, streaming events through any tool-use rounds.

        The model is called again after each response that stops with
        ``tool_use``, once its tools have been executed. Exceptions (other
        than ``asyncio.CancelledError``) are reported as an ``"error"`` event.

        Args:
            user_query: The user's natural-language request.
            use_history: Whether to send prior turns and record this turn in
                ``conversation_history``.
            cancellation_token: When set, streaming stops at the next check.

        History is only updated once the turn finishes with a final response.
        It is then written in one go: the user query, every tool exchange, and
        the final answer. A cancelled or failed turn leaves the history
        unchanged, so the next request never starts from a half-recorded
        exchange.
        """
        # Reset per-query tracking state.
        self._last_results = None
        self._last_query = None

        try:
            messages = self._convert_history_to_messages() if use_history else []

            # With OAuth the system prompt carries no SQL instructions, so they
            # lead every request as a user message. They are not stored in
            # history, so later turns keep them and they reflect current memories.
            if self._uses_oauth():
                instructions = self._get_sql_assistant_instructions()
                messages.insert(0, Message(MessageRole.USER, instructions))

            messages.append(Message(MessageRole.USER, user_query))

            response = None
            async for event in self._handle_stream_events(
                self.client.create_message_with_tools(
                    self._create_message_request(messages), cancellation_token
                ),
                cancellation_token,
            ):
                if isinstance(event, StreamEvent):
                    yield event
                else:
                    response = event

            # Assistant tool-use turns and user tool-result turns for this query,
            # in order. They are committed to history only when the turn completes.
            exchanges: list[dict[str, Any]] = []
            while response is not None and response.stop_reason == "tool_use":
                if self._is_cancelled(cancellation_token):
                    return

                tool_results: list[dict[str, Any]] = []
                async for event in self._execute_and_yield_tool_results(
                    response.content, cancellation_token
                ):
                    if isinstance(event, StreamEvent):
                        yield event
                    elif isinstance(event, list):
                        tool_results = event

                exchanges.append({"role": "assistant", "content": response.content})
                exchanges.append({"role": "user", "content": tool_results})

                if self._is_cancelled(cancellation_token):
                    return

                yield StreamEvent("processing", "Analyzing results...")

                # Extend the outgoing conversation with this round's exchange.
                messages.append(
                    self._convert_response_content_to_message(response.content)
                )
                messages.append(self._convert_tool_results_to_message(tool_results))

                response = None
                async for event in self._handle_stream_events(
                    self.client.create_message_with_tools(
                        self._create_message_request(messages), cancellation_token
                    ),
                    cancellation_token,
                ):
                    if isinstance(event, StreamEvent):
                        yield event
                    else:
                        response = event

            if use_history and response is not None:
                self.conversation_history.append({"role": "user", "content": user_query})
                self.conversation_history.extend(exchanges)
                self.conversation_history.append(
                    {"role": "assistant", "content": response.content}
                )

        except asyncio.CancelledError:
            return
        except Exception as e:
            yield StreamEvent("error", str(e))

    async def close(self) -> None:
        """Close the underlying Anthropic client."""
        await self.client.close()
@@ -3,7 +3,7 @@
3
3
  import asyncio
4
4
  import json
5
5
  from abc import ABC, abstractmethod
6
- from typing import Any, AsyncIterator, Dict, List, Optional
6
+ from typing import Any, AsyncIterator
7
7
 
8
8
  from uniplot import histogram, plot
9
9
 
@@ -24,7 +24,7 @@ class BaseSQLAgent(ABC):
24
24
  def __init__(self, db_connection: BaseDatabaseConnection):
25
25
  self.db = db_connection
26
26
  self.schema_manager = SchemaManager(db_connection)
27
- self.conversation_history: List[Dict[str, Any]] = []
27
+ self.conversation_history: list[dict[str, Any]] = []
28
28
 
29
29
  @abstractmethod
30
30
  async def query_stream(
@@ -59,7 +59,7 @@ class BaseSQLAgent(ABC):
59
59
  else:
60
60
  return "database" # Fallback
61
61
 
62
- async def introspect_schema(self, table_pattern: Optional[str] = None) -> str:
62
+ async def introspect_schema(self, table_pattern: str | None = None) -> str:
63
63
  """Introspect database schema to understand table structures."""
64
64
  try:
65
65
  # Pass table_pattern to get_schema_info for efficient filtering at DB level
@@ -96,7 +96,7 @@ class BaseSQLAgent(ABC):
96
96
  except Exception as e:
97
97
  return json.dumps({"error": f"Error listing tables: {str(e)}"})
98
98
 
99
- async def execute_sql(self, query: str, limit: Optional[int] = None) -> str:
99
+ async def execute_sql(self, query: str, limit: int | None = None) -> str:
100
100
  """Execute a SQL query against the database."""
101
101
  try:
102
102
  # Security check - only allow SELECT queries unless write is enabled
@@ -147,7 +147,7 @@ class BaseSQLAgent(ABC):
147
147
  return json.dumps({"error": error_msg, "suggestions": suggestions})
148
148
 
149
149
  async def process_tool_call(
150
- self, tool_name: str, tool_input: Dict[str, Any]
150
+ self, tool_name: str, tool_input: dict[str, Any]
151
151
  ) -> str:
152
152
  """Process a tool call and return the result."""
153
153
  if tool_name == "list_tables":
@@ -170,7 +170,7 @@ class BaseSQLAgent(ABC):
170
170
  else:
171
171
  return json.dumps({"error": f"Unknown tool: {tool_name}"})
172
172
 
173
- def _validate_write_operation(self, query: str) -> Optional[str]:
173
+ def _validate_write_operation(self, query: str) -> str | None:
174
174
  """Validate if a write operation is allowed.
175
175
 
176
176
  Returns:
@@ -206,12 +206,12 @@ class BaseSQLAgent(ABC):
206
206
 
207
207
  async def plot_data(
208
208
  self,
209
- y_values: List[float],
210
- x_values: Optional[List[float]] = None,
209
+ y_values: list[float],
210
+ x_values: list[float] | None = None,
211
211
  plot_type: str = "line",
212
- title: Optional[str] = None,
213
- x_label: Optional[str] = None,
214
- y_label: Optional[str] = None,
212
+ title: str | None = None,
213
+ x_label: str | None = None,
214
+ y_label: str | None = None,
215
215
  ) -> str:
216
216
  """Create a terminal plot using uniplot.
217
217
 
@@ -1,16 +1,16 @@
1
1
  """Streaming utilities for agents."""
2
2
 
3
- from typing import Any, Dict, List
3
+ from typing import Any
4
4
 
5
5
 
6
6
  class StreamingResponse:
7
7
  """Helper class to manage streaming response construction."""
8
8
 
9
- def __init__(self, content: List[Dict[str, Any]], stop_reason: str):
9
+ def __init__(self, content: list[dict[str, Any]], stop_reason: str):
10
10
  self.content = content
11
11
  self.stop_reason = stop_reason
12
12
 
13
13
 
14
- def build_tool_result_block(tool_use_id: str, content: str) -> Dict[str, Any]:
14
def build_tool_result_block(tool_use_id: str, content: str) -> dict[str, Any]:
    """Return a ``tool_result`` content block answering one tool call.

    Args:
        tool_use_id: The ``id`` of the ``tool_use`` block being answered.
        content: The tool's output, passed through unchanged.
    """
    block: dict[str, Any] = dict(type="tool_result")
    block["tool_use_id"] = tool_use_id
    block["content"] = content
    return block