sqlsaber 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber has been flagged as potentially problematic. Click here for more details.
- sqlsaber/__init__.py +3 -0
- sqlsaber/__main__.py +4 -0
- sqlsaber/agents/__init__.py +9 -0
- sqlsaber/agents/anthropic.py +451 -0
- sqlsaber/agents/base.py +67 -0
- sqlsaber/agents/streaming.py +26 -0
- sqlsaber/cli/__init__.py +7 -0
- sqlsaber/cli/commands.py +132 -0
- sqlsaber/cli/database.py +275 -0
- sqlsaber/cli/display.py +207 -0
- sqlsaber/cli/interactive.py +93 -0
- sqlsaber/cli/memory.py +239 -0
- sqlsaber/cli/models.py +231 -0
- sqlsaber/cli/streaming.py +94 -0
- sqlsaber/config/__init__.py +7 -0
- sqlsaber/config/api_keys.py +102 -0
- sqlsaber/config/database.py +252 -0
- sqlsaber/config/settings.py +115 -0
- sqlsaber/database/__init__.py +9 -0
- sqlsaber/database/connection.py +187 -0
- sqlsaber/database/schema.py +678 -0
- sqlsaber/memory/__init__.py +1 -0
- sqlsaber/memory/manager.py +77 -0
- sqlsaber/memory/storage.py +176 -0
- sqlsaber/models/__init__.py +13 -0
- sqlsaber/models/events.py +28 -0
- sqlsaber/models/types.py +40 -0
- sqlsaber-0.1.0.dist-info/METADATA +168 -0
- sqlsaber-0.1.0.dist-info/RECORD +32 -0
- sqlsaber-0.1.0.dist-info/WHEEL +4 -0
- sqlsaber-0.1.0.dist-info/entry_points.txt +4 -0
- sqlsaber-0.1.0.dist-info/licenses/LICENSE +201 -0
sqlsaber/__init__.py
ADDED
sqlsaber/__main__.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
"""Anthropic-specific SQL agent implementation."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
from anthropic import AsyncAnthropic
|
|
7
|
+
|
|
8
|
+
from sqlsaber.agents.base import BaseSQLAgent
|
|
9
|
+
from sqlsaber.agents.streaming import (
|
|
10
|
+
StreamingResponse,
|
|
11
|
+
build_tool_result_block,
|
|
12
|
+
)
|
|
13
|
+
from sqlsaber.config.settings import Config
|
|
14
|
+
from sqlsaber.database.connection import (
|
|
15
|
+
BaseDatabaseConnection,
|
|
16
|
+
MySQLConnection,
|
|
17
|
+
PostgreSQLConnection,
|
|
18
|
+
SQLiteConnection,
|
|
19
|
+
)
|
|
20
|
+
from sqlsaber.database.schema import SchemaManager
|
|
21
|
+
from sqlsaber.memory.manager import MemoryManager
|
|
22
|
+
from sqlsaber.models.events import StreamEvent
|
|
23
|
+
from sqlsaber.models.types import ToolDefinition
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class AnthropicSQLAgent(BaseSQLAgent):
|
|
27
|
+
"""SQL Agent using Anthropic SDK directly."""
|
|
28
|
+
|
|
29
|
+
    def __init__(
        self, db_connection: BaseDatabaseConnection, database_name: Optional[str] = None
    ):
        """Create an agent bound to a single database connection.

        Args:
            db_connection: Live connection the agent will run queries against.
            database_name: Optional name used to look up stored memories;
                when None, memory injection into the prompt is skipped.

        Raises:
            ValueError: From ``config.validate()`` when the API key is missing.
        """
        super().__init__(db_connection)

        config = Config()
        config.validate()  # This will raise ValueError if API key is missing

        self.client = AsyncAnthropic(api_key=config.api_key)
        # Configured model names carry an "anthropic:" prefix; the SDK wants
        # the bare model id.
        self.model = config.model_name.replace("anthropic:", "")
        self.schema_manager = SchemaManager(db_connection)

        self.database_name = database_name
        self.memory_manager = MemoryManager()

        # Track last query results for streaming
        self._last_results = None
        self._last_query = None

        # Define tools in Anthropic format
        self.tools: List[ToolDefinition] = [
            {
                "name": "list_tables",
                "description": "Get a list of all tables in the database with row counts. Use this first to discover available tables.",
                "input_schema": {
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            },
            {
                "name": "introspect_schema",
                "description": "Introspect database schema to understand table structures.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "table_pattern": {
                            "type": "string",
                            "description": "Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')",
                        }
                    },
                    "required": [],
                },
            },
            {
                "name": "execute_sql",
                "description": "Execute a SQL query against the database.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "SQL query to execute",
                        },
                        "limit": {
                            "type": "integer",
                            "description": "Maximum number of rows to return (default: 100)",
                            "default": 100,
                        },
                    },
                    "required": ["query"],
                },
            },
        ]

        # Build system prompt with memories if available
        self.system_prompt = self._build_system_prompt()
|
|
96
|
+
|
|
97
|
+
def _get_database_type_name(self) -> str:
|
|
98
|
+
"""Get the human-readable database type name."""
|
|
99
|
+
if isinstance(self.db, PostgreSQLConnection):
|
|
100
|
+
return "PostgreSQL"
|
|
101
|
+
elif isinstance(self.db, MySQLConnection):
|
|
102
|
+
return "MySQL"
|
|
103
|
+
elif isinstance(self.db, SQLiteConnection):
|
|
104
|
+
return "SQLite"
|
|
105
|
+
else:
|
|
106
|
+
return "database" # Fallback
|
|
107
|
+
|
|
108
|
+
    def _build_system_prompt(self) -> str:
        """Build system prompt with optional memory context.

        Returns the base instructions for the agent, with any stored
        memories for ``self.database_name`` appended when available.
        """
        db_type = self._get_database_type_name()
        base_prompt = f"""You are a helpful SQL assistant that helps users query their {db_type} database.

Your responsibilities:
1. Understand user's natural language requests, think and convert them to SQL
2. Use the provided tools efficiently to explore database schema
3. Generate appropriate SQL queries
4. Execute queries safely (only SELECT queries unless explicitly allowed)
5. Format and explain results clearly

IMPORTANT - Schema Discovery Strategy:
1. ALWAYS start with 'list_tables' to see available tables and row counts
2. Based on the user's query, identify which specific tables are relevant
3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables

Guidelines:
- Use list_tables first, then introspect_schema for specific tables only
- Use table patterns like 'sample%' or '%experiment%' to filter related tables
- Use proper JOIN syntax and avoid cartesian products
- Include appropriate WHERE clauses to limit results
- Explain what the query does in simple terms
- Handle errors gracefully and suggest fixes
- Be security conscious - use parameterized queries when needed
"""

        # Add memory context if database name is available
        if self.database_name:
            memory_context = self.memory_manager.format_memories_for_prompt(
                self.database_name
            )
            # Only append when there is actual memory text to add.
            if memory_context.strip():
                base_prompt += memory_context

        return base_prompt
|
|
144
|
+
|
|
145
|
+
def add_memory(self, content: str) -> Optional[str]:
|
|
146
|
+
"""Add a memory for the current database."""
|
|
147
|
+
if not self.database_name:
|
|
148
|
+
return None
|
|
149
|
+
|
|
150
|
+
memory = self.memory_manager.add_memory(self.database_name, content)
|
|
151
|
+
# Rebuild system prompt with new memory
|
|
152
|
+
self.system_prompt = self._build_system_prompt()
|
|
153
|
+
return memory.id
|
|
154
|
+
|
|
155
|
+
async def introspect_schema(self, table_pattern: Optional[str] = None) -> str:
|
|
156
|
+
"""Introspect database schema to understand table structures."""
|
|
157
|
+
try:
|
|
158
|
+
# Pass table_pattern to get_schema_info for efficient filtering at DB level
|
|
159
|
+
schema_info = await self.schema_manager.get_schema_info(table_pattern)
|
|
160
|
+
|
|
161
|
+
# Format the schema information
|
|
162
|
+
formatted_info = {}
|
|
163
|
+
for table_name, table_info in schema_info.items():
|
|
164
|
+
formatted_info[table_name] = {
|
|
165
|
+
"columns": {
|
|
166
|
+
col_name: {
|
|
167
|
+
"type": col_info["data_type"],
|
|
168
|
+
"nullable": col_info["nullable"],
|
|
169
|
+
"default": col_info["default"],
|
|
170
|
+
}
|
|
171
|
+
for col_name, col_info in table_info["columns"].items()
|
|
172
|
+
},
|
|
173
|
+
"primary_keys": table_info["primary_keys"],
|
|
174
|
+
"foreign_keys": [
|
|
175
|
+
f"{fk['column']} -> {fk['references']['table']}.{fk['references']['column']}"
|
|
176
|
+
for fk in table_info["foreign_keys"]
|
|
177
|
+
],
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
return json.dumps(formatted_info)
|
|
181
|
+
except Exception as e:
|
|
182
|
+
return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
|
|
183
|
+
|
|
184
|
+
async def list_tables(self) -> str:
|
|
185
|
+
"""List all tables in the database with basic information."""
|
|
186
|
+
try:
|
|
187
|
+
tables_info = await self.schema_manager.list_tables()
|
|
188
|
+
return json.dumps(tables_info)
|
|
189
|
+
except Exception as e:
|
|
190
|
+
return json.dumps({"error": f"Error listing tables: {str(e)}"})
|
|
191
|
+
|
|
192
|
+
async def execute_sql(self, query: str, limit: Optional[int] = 100) -> str:
|
|
193
|
+
"""Execute a SQL query against the database."""
|
|
194
|
+
try:
|
|
195
|
+
# Security check - only allow SELECT queries unless write is enabled
|
|
196
|
+
write_error = self._validate_write_operation(query)
|
|
197
|
+
if write_error:
|
|
198
|
+
return json.dumps(
|
|
199
|
+
{
|
|
200
|
+
"error": write_error,
|
|
201
|
+
}
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
# Add LIMIT if not present and it's a SELECT query
|
|
205
|
+
query = self._add_limit_to_query(query, limit)
|
|
206
|
+
|
|
207
|
+
# Execute the query (wrapped in a transaction for safety)
|
|
208
|
+
results = await self.db.execute_query(query)
|
|
209
|
+
|
|
210
|
+
# Format results - but also store the actual data
|
|
211
|
+
actual_limit = limit if limit is not None else len(results)
|
|
212
|
+
self._last_results = results[:actual_limit]
|
|
213
|
+
self._last_query = query
|
|
214
|
+
|
|
215
|
+
return json.dumps(
|
|
216
|
+
{
|
|
217
|
+
"success": True,
|
|
218
|
+
"row_count": len(results),
|
|
219
|
+
"results": results[:actual_limit], # Extra safety for limit
|
|
220
|
+
"truncated": len(results) > actual_limit,
|
|
221
|
+
}
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
except Exception as e:
|
|
225
|
+
error_msg = str(e)
|
|
226
|
+
|
|
227
|
+
# Provide helpful error messages
|
|
228
|
+
suggestions = []
|
|
229
|
+
if "column" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
230
|
+
suggestions.append(
|
|
231
|
+
"Check column names using the schema introspection tool"
|
|
232
|
+
)
|
|
233
|
+
elif "table" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
234
|
+
suggestions.append(
|
|
235
|
+
"Check table names using the schema introspection tool"
|
|
236
|
+
)
|
|
237
|
+
elif "syntax error" in error_msg.lower():
|
|
238
|
+
suggestions.append(
|
|
239
|
+
"Review SQL syntax, especially JOIN conditions and WHERE clauses"
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
return json.dumps({"error": error_msg, "suggestions": suggestions})
|
|
243
|
+
|
|
244
|
+
async def process_tool_call(
|
|
245
|
+
self, tool_name: str, tool_input: Dict[str, Any]
|
|
246
|
+
) -> str:
|
|
247
|
+
"""Process a tool call and return the result."""
|
|
248
|
+
if tool_name == "list_tables":
|
|
249
|
+
return await self.list_tables()
|
|
250
|
+
elif tool_name == "introspect_schema":
|
|
251
|
+
return await self.introspect_schema(tool_input.get("table_pattern"))
|
|
252
|
+
elif tool_name == "execute_sql":
|
|
253
|
+
return await self.execute_sql(
|
|
254
|
+
tool_input["query"], tool_input.get("limit", 100)
|
|
255
|
+
)
|
|
256
|
+
else:
|
|
257
|
+
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
258
|
+
|
|
259
|
+
    async def _process_stream_events(
        self, stream, content_blocks: List[Dict], tool_use_blocks: List[Dict]
    ) -> AsyncIterator[StreamEvent]:
        """Process stream events and yield appropriate StreamEvents.

        Mutates ``content_blocks`` (text blocks) and ``tool_use_blocks``
        (tool calls, plus a private "_partial" JSON buffer) in place so the
        caller can assemble the final response after the stream ends.
        """
        async for event in stream:
            if event.type == "content_block_start":
                if hasattr(event.content_block, "type"):
                    if event.content_block.type == "tool_use":
                        yield StreamEvent(
                            "tool_use",
                            {"name": event.content_block.name, "status": "started"},
                        )
                        # Input arrives later as deltas; start with an empty dict.
                        tool_use_blocks.append(
                            {
                                "id": event.content_block.id,
                                "name": event.content_block.name,
                                "input": {},
                            }
                        )
                    elif event.content_block.type == "text":
                        content_blocks.append({"type": "text", "text": ""})

            elif event.type == "content_block_delta":
                if hasattr(event.delta, "text"):
                    yield StreamEvent("text", event.delta.text)
                    # Accumulate into the most recent text block, if any.
                    if content_blocks and content_blocks[-1]["type"] == "text":
                        content_blocks[-1]["text"] += event.delta.text
                elif hasattr(event.delta, "partial_json"):
                    if tool_use_blocks:
                        # Tool input streams as incremental JSON fragments:
                        # keep the raw accumulation in "_partial" and refresh
                        # "input" whenever the buffer parses cleanly.
                        try:
                            current_json = tool_use_blocks[-1].get("_partial", "")
                            current_json += event.delta.partial_json
                            tool_use_blocks[-1]["_partial"] = current_json
                            tool_use_blocks[-1]["input"] = json.loads(current_json)
                        except json.JSONDecodeError:
                            # Fragment not yet complete JSON; wait for more.
                            pass

            elif event.type == "message_stop":
                break
|
|
298
|
+
|
|
299
|
+
def _finalize_tool_blocks(self, tool_use_blocks: List[Dict]) -> str:
|
|
300
|
+
"""Finalize tool use blocks and return stop reason."""
|
|
301
|
+
if tool_use_blocks:
|
|
302
|
+
for block in tool_use_blocks:
|
|
303
|
+
block["type"] = "tool_use"
|
|
304
|
+
if "_partial" in block:
|
|
305
|
+
del block["_partial"]
|
|
306
|
+
return "tool_use"
|
|
307
|
+
return "stop"
|
|
308
|
+
|
|
309
|
+
    async def _process_tool_results(
        self, response: StreamingResponse
    ) -> AsyncIterator[StreamEvent]:
        """Process tool results and yield appropriate events.

        Executes every tool_use block of ``response`` in order, yielding
        progress/result events as it goes. The final event is always
        "tool_result_data" carrying the tool_result blocks to feed back
        into the conversation.
        """
        tool_results = []
        for block in response.content:
            if block.get("type") == "tool_use":
                yield StreamEvent(
                    "tool_use",
                    {
                        "name": block["name"],
                        "input": block["input"],
                        "status": "executing",
                    },
                )

                tool_result = await self.process_tool_call(
                    block["name"], block["input"]
                )

                # Yield specific events based on tool type
                if block["name"] == "execute_sql" and self._last_results:
                    # execute_sql stashed its rows on self; emit them
                    # structured rather than as the raw JSON string.
                    yield StreamEvent(
                        "query_result",
                        {
                            "query": self._last_query,
                            "results": self._last_results,
                        },
                    )
                elif block["name"] in ["list_tables", "introspect_schema"]:
                    yield StreamEvent(
                        "tool_result",
                        {
                            "tool_name": block["name"],
                            "result": tool_result,
                        },
                    )

                tool_results.append(build_tool_result_block(block["id"], tool_result))

        yield StreamEvent("tool_result_data", tool_results)
|
|
350
|
+
|
|
351
|
+
    async def query_stream(
        self, user_query: str, use_history: bool = True
    ) -> AsyncIterator[StreamEvent]:
        """Process a user query and stream responses.

        Drives the tool-use loop: stream one model response, execute any
        requested tools, feed the results back, and repeat until the model
        stops requesting tools. Yields StreamEvents for the UI layer and
        never raises — failures surface as an "error" event.
        """
        # Initialize for tracking state
        self._last_results = None
        self._last_query = None

        # Build messages with history if requested
        if use_history:
            messages = self.conversation_history + [
                {"role": "user", "content": user_query}
            ]
        else:
            messages = [{"role": "user", "content": user_query}]

        try:
            # Create initial stream and get response
            response = None
            async for event in self._create_and_process_stream(messages):
                if event.type == "response_ready":
                    response = event.data
                else:
                    yield event

            # Messages generated this turn (assistant replies + tool results);
            # appended to history at the end when use_history is True.
            collected_content = []

            # Process tool calls if needed
            while response is not None and response.stop_reason == "tool_use":
                # Add assistant's response to conversation
                collected_content.append(
                    {"role": "assistant", "content": response.content}
                )

                # Process tool results
                tool_results = []
                async for event in self._process_tool_results(response):
                    if event.type == "tool_result_data":
                        tool_results = event.data
                    else:
                        yield event

                # Continue conversation with tool results
                collected_content.append({"role": "user", "content": tool_results})

                # Signal that we're processing the tool results
                yield StreamEvent("processing", "Analyzing results...")

                # Get next response
                response = None
                async for event in self._create_and_process_stream(
                    messages + collected_content
                ):
                    if event.type == "response_ready":
                        response = event.data
                    else:
                        yield event

            # Update conversation history if using history
            if use_history:
                self.conversation_history.append(
                    {"role": "user", "content": user_query}
                )
                self.conversation_history.extend(collected_content)
                # Add final assistant response
                if response is not None:
                    self.conversation_history.append(
                        {"role": "assistant", "content": response.content}
                    )

        except Exception as e:
            # Surface failures as an event rather than raising into the UI.
            yield StreamEvent("error", str(e))
|
|
423
|
+
|
|
424
|
+
    async def _create_and_process_stream(
        self, messages: List[Dict]
    ) -> AsyncIterator[StreamEvent]:
        """Create a stream and yield events while building response.

        Opens one Anthropic streaming request for ``messages``, relays its
        events, then emits a final "response_ready" event whose data is the
        assembled StreamingResponse.
        """
        stream = await self.client.messages.create(
            model=self.model,
            max_tokens=4096,
            system=self.system_prompt,
            messages=messages,
            tools=self.tools,
            stream=True,
        )

        # Filled in place by _process_stream_events as deltas arrive.
        content_blocks = []
        tool_use_blocks = []

        async for event in self._process_stream_events(
            stream, content_blocks, tool_use_blocks
        ):
            yield event

        # Finalize tool blocks and create response
        stop_reason = self._finalize_tool_blocks(tool_use_blocks)
        content_blocks.extend(tool_use_blocks)

        yield StreamEvent(
            "response_ready", StreamingResponse(content_blocks, stop_reason)
        )
|
sqlsaber/agents/base.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Abstract base class for SQL agents."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
from sqlsaber.database.connection import BaseDatabaseConnection
|
|
7
|
+
from sqlsaber.models.events import StreamEvent
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseSQLAgent(ABC):
|
|
11
|
+
"""Abstract base class for SQL agents."""
|
|
12
|
+
|
|
13
|
+
def __init__(self, db_connection: BaseDatabaseConnection):
|
|
14
|
+
self.db = db_connection
|
|
15
|
+
self.conversation_history: List[Dict[str, Any]] = []
|
|
16
|
+
|
|
17
|
+
@abstractmethod
|
|
18
|
+
async def query_stream(
|
|
19
|
+
self, user_query: str, use_history: bool = True
|
|
20
|
+
) -> AsyncIterator[StreamEvent]:
|
|
21
|
+
"""Process a user query and stream responses."""
|
|
22
|
+
pass
|
|
23
|
+
|
|
24
|
+
def clear_history(self):
|
|
25
|
+
"""Clear conversation history."""
|
|
26
|
+
self.conversation_history = []
|
|
27
|
+
|
|
28
|
+
@abstractmethod
|
|
29
|
+
async def process_tool_call(
|
|
30
|
+
self, tool_name: str, tool_input: Dict[str, Any]
|
|
31
|
+
) -> str:
|
|
32
|
+
"""Process a tool call and return the result."""
|
|
33
|
+
pass
|
|
34
|
+
|
|
35
|
+
def _validate_write_operation(self, query: str) -> Optional[str]:
|
|
36
|
+
"""Validate if a write operation is allowed.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
None if operation is allowed, error message if not allowed.
|
|
40
|
+
"""
|
|
41
|
+
query_upper = query.strip().upper()
|
|
42
|
+
|
|
43
|
+
# Check for write operations
|
|
44
|
+
write_keywords = [
|
|
45
|
+
"INSERT",
|
|
46
|
+
"UPDATE",
|
|
47
|
+
"DELETE",
|
|
48
|
+
"DROP",
|
|
49
|
+
"CREATE",
|
|
50
|
+
"ALTER",
|
|
51
|
+
"TRUNCATE",
|
|
52
|
+
]
|
|
53
|
+
is_write_query = any(query_upper.startswith(kw) for kw in write_keywords)
|
|
54
|
+
|
|
55
|
+
if is_write_query:
|
|
56
|
+
return (
|
|
57
|
+
"Write operations are not allowed. Only SELECT queries are permitted."
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
return None
|
|
61
|
+
|
|
62
|
+
def _add_limit_to_query(self, query: str, limit: int = 100) -> str:
|
|
63
|
+
"""Add LIMIT clause to SELECT queries if not present."""
|
|
64
|
+
query_upper = query.strip().upper()
|
|
65
|
+
if query_upper.startswith("SELECT") and "LIMIT" not in query_upper:
|
|
66
|
+
return f"{query.rstrip(';')} LIMIT {limit};"
|
|
67
|
+
return query
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Streaming utilities for agents."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class StreamingResponse:
    """Helper class to manage streaming response construction.

    Attributes:
        content: Accumulated content blocks (text and tool_use dicts).
        stop_reason: Why generation stopped (e.g. "stop" or "tool_use").
    """

    def __init__(self, content: List[Dict[str, Any]], stop_reason: str):
        self.content = content
        self.stop_reason = stop_reason

    def __repr__(self) -> str:
        # Compact repr for debugging: block count, not full block bodies.
        return (
            f"{type(self).__name__}(blocks={len(self.content)}, "
            f"stop_reason={self.stop_reason!r})"
        )
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def build_tool_result_block(tool_use_id: str, content: str) -> Dict[str, Any]:
    """Build a tool result block for the conversation."""
    block: Dict[str, Any] = {"type": "tool_result"}
    block["tool_use_id"] = tool_use_id
    block["content"] = content
    return block
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def extract_sql_from_text(text: str) -> str:
    """Extract SQL query from markdown-formatted text.

    Returns the stripped contents of the first ```sql fenced block, or an
    empty string when no complete, non-empty block is present.
    """
    _, fence_open, tail = text.partition("```sql")
    if not fence_open:
        # No ```sql fence anywhere in the text.
        return ""
    body, fence_close, _ = tail.partition("```")
    if not fence_close or not body:
        # Unterminated or empty block.
        return ""
    return body.strip()
|