sqlsaber 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

sqlsaber/__init__.py ADDED
@@ -0,0 +1,3 @@
1
"""SQLSaber CLI - SQL like Claude Code."""

# Version string of the sqlsaber package.
__version__ = "0.1.0"
sqlsaber/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ from sqlsaber.cli.commands import main
2
+
3
# Run the CLI entry point only when this module is executed as a script.
if __name__ == "__main__":
    main()
sqlsaber/agents/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """Agents module for SQLSaber."""
2
+
3
+ from .anthropic import AnthropicSQLAgent
4
+ from .base import BaseSQLAgent
5
+
6
# Names re-exported by the agents package.
__all__ = ["BaseSQLAgent", "AnthropicSQLAgent"]
sqlsaber/agents/anthropic.py ADDED
@@ -0,0 +1,451 @@
1
+ """Anthropic-specific SQL agent implementation."""
2
+
3
+ import json
4
+ from typing import Any, AsyncIterator, Dict, List, Optional
5
+
6
+ from anthropic import AsyncAnthropic
7
+
8
+ from sqlsaber.agents.base import BaseSQLAgent
9
+ from sqlsaber.agents.streaming import (
10
+ StreamingResponse,
11
+ build_tool_result_block,
12
+ )
13
+ from sqlsaber.config.settings import Config
14
+ from sqlsaber.database.connection import (
15
+ BaseDatabaseConnection,
16
+ MySQLConnection,
17
+ PostgreSQLConnection,
18
+ SQLiteConnection,
19
+ )
20
+ from sqlsaber.database.schema import SchemaManager
21
+ from sqlsaber.memory.manager import MemoryManager
22
+ from sqlsaber.models.events import StreamEvent
23
+ from sqlsaber.models.types import ToolDefinition
24
+
25
+
26
class AnthropicSQLAgent(BaseSQLAgent):
    """SQL agent that talks to the Anthropic Messages API directly.

    The agent offers three tools to the model (``list_tables``,
    ``introspect_schema`` and ``execute_sql``), streams the model output as
    ``StreamEvent`` objects, and keeps feeding tool results back to the model
    until it stops requesting tools.
    """

    def __init__(
        self, db_connection: BaseDatabaseConnection, database_name: Optional[str] = None
    ):
        """Create the agent.

        Args:
            db_connection: Connection used for schema inspection and queries.
            database_name: Optional name used to look up and store memories
                for this database; memories are skipped when it is falsy.
        """
        super().__init__(db_connection)

        config = Config()
        config.validate()  # This will raise ValueError if API key is missing

        self.client = AsyncAnthropic(api_key=config.api_key)
        # The configured model name may carry an "anthropic:" marker; drop it.
        self.model = config.model_name.replace("anthropic:", "")
        self.schema_manager = SchemaManager(db_connection)

        self.database_name = database_name
        self.memory_manager = MemoryManager()

        # Rows and SQL text of the most recent successful execute_sql call;
        # read by _process_tool_results to emit "query_result" events.
        self._last_results = None
        self._last_query = None

        # Tool definitions in the format expected by the Anthropic API.
        self.tools: List[ToolDefinition] = [
            {
                "name": "list_tables",
                "description": "Get a list of all tables in the database with row counts. Use this first to discover available tables.",
                "input_schema": {
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            },
            {
                "name": "introspect_schema",
                "description": "Introspect database schema to understand table structures.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "table_pattern": {
                            "type": "string",
                            "description": "Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')",
                        }
                    },
                    "required": [],
                },
            },
            {
                "name": "execute_sql",
                "description": "Execute a SQL query against the database.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "SQL query to execute",
                        },
                        "limit": {
                            "type": "integer",
                            "description": "Maximum number of rows to return (default: 100)",
                            "default": 100,
                        },
                    },
                    "required": ["query"],
                },
            },
        ]

        # Build system prompt with memories if available
        self.system_prompt = self._build_system_prompt()

    def _get_database_type_name(self) -> str:
        """Return a human-readable name for the type of ``self.db``."""
        if isinstance(self.db, PostgreSQLConnection):
            return "PostgreSQL"
        if isinstance(self.db, MySQLConnection):
            return "MySQL"
        if isinstance(self.db, SQLiteConnection):
            return "SQLite"
        return "database"  # Fallback for connection types not listed above

    def _build_system_prompt(self) -> str:
        """Build the system prompt, appending stored memories when available.

        Memories are only looked up when ``self.database_name`` is set, and
        only appended when the formatted memory text is not blank.
        """
        db_type = self._get_database_type_name()
        base_prompt = f"""You are a helpful SQL assistant that helps users query their {db_type} database.

Your responsibilities:
1. Understand user's natural language requests, think and convert them to SQL
2. Use the provided tools efficiently to explore database schema
3. Generate appropriate SQL queries
4. Execute queries safely (only SELECT queries unless explicitly allowed)
5. Format and explain results clearly

IMPORTANT - Schema Discovery Strategy:
1. ALWAYS start with 'list_tables' to see available tables and row counts
2. Based on the user's query, identify which specific tables are relevant
3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables

Guidelines:
- Use list_tables first, then introspect_schema for specific tables only
- Use table patterns like 'sample%' or '%experiment%' to filter related tables
- Use proper JOIN syntax and avoid cartesian products
- Include appropriate WHERE clauses to limit results
- Explain what the query does in simple terms
- Handle errors gracefully and suggest fixes
- Be security conscious - use parameterized queries when needed
"""

        # Add memory context if database name is available
        if self.database_name:
            memory_context = self.memory_manager.format_memories_for_prompt(
                self.database_name
            )
            if memory_context.strip():
                base_prompt += memory_context

        return base_prompt

    def add_memory(self, content: str) -> Optional[str]:
        """Store a memory for the current database and refresh the prompt.

        Returns:
            The ``id`` of the stored memory, or ``None`` when no database
            name is configured (nothing is stored in that case).
        """
        if not self.database_name:
            return None

        memory = self.memory_manager.add_memory(self.database_name, content)
        # Rebuild system prompt so the new memory is part of later requests.
        self.system_prompt = self._build_system_prompt()
        return memory.id

    async def introspect_schema(self, table_pattern: Optional[str] = None) -> str:
        """Return table structures as a JSON string.

        Args:
            table_pattern: Optional pattern forwarded to the schema manager
                to restrict which tables are described.

        Returns:
            JSON object keyed by table name with ``columns``,
            ``primary_keys`` and ``foreign_keys``; on failure a JSON object
            with a single ``error`` key.
        """
        try:
            # The pattern is handed to the schema manager so it can filter.
            schema_info = await self.schema_manager.get_schema_info(table_pattern)

            formatted_info = {}
            for table_name, table_info in schema_info.items():
                formatted_info[table_name] = {
                    "columns": {
                        col_name: {
                            "type": col_info["data_type"],
                            "nullable": col_info["nullable"],
                            "default": col_info["default"],
                        }
                        for col_name, col_info in table_info["columns"].items()
                    },
                    "primary_keys": table_info["primary_keys"],
                    "foreign_keys": [
                        f"{fk['column']} -> {fk['references']['table']}.{fk['references']['column']}"
                        for fk in table_info["foreign_keys"]
                    ],
                }

            return json.dumps(formatted_info)
        except Exception as e:
            # Tool errors are reported to the model as data, not raised.
            return json.dumps({"error": f"Error introspecting schema: {str(e)}"})

    async def list_tables(self) -> str:
        """Return the schema manager's table listing as a JSON string.

        On failure a JSON object with a single ``error`` key is returned.
        """
        try:
            tables_info = await self.schema_manager.list_tables()
            return json.dumps(tables_info)
        except Exception as e:
            return json.dumps({"error": f"Error listing tables: {str(e)}"})

    async def execute_sql(self, query: str, limit: Optional[int] = 100) -> str:
        """Execute a SQL query and return the outcome as a JSON string.

        Args:
            query: SQL text to run. Write statements are rejected by
                ``_validate_write_operation``.
            limit: Maximum number of rows to return; ``None`` disables both
                the added LIMIT clause and the result slicing.

        Returns:
            On success a JSON object with ``success``, ``row_count``,
            ``results`` and ``truncated``; otherwise a JSON object with
            ``error`` (and ``suggestions`` for execution failures).
        """
        # Forget the previous execution first, so a rejected or failing query
        # can never be reported together with an earlier query's rows.
        self._last_results = None
        self._last_query = None

        try:
            # Security check - only allow SELECT queries unless write is enabled
            write_error = self._validate_write_operation(query)
            if write_error:
                return json.dumps({"error": write_error})

            # Add LIMIT if not present and it's a SELECT query. With no limit
            # requested the query is left alone (avoids emitting "LIMIT None").
            if limit is not None:
                query = self._add_limit_to_query(query, limit)

            results = await self.db.execute_query(query)

            # Slice as an extra safety net in case the query carried its own,
            # larger LIMIT.
            actual_limit = limit if limit is not None else len(results)
            limited_results = results[:actual_limit]
            self._last_results = limited_results
            self._last_query = query

            return json.dumps(
                {
                    "success": True,
                    "row_count": len(results),
                    "results": limited_results,
                    "truncated": len(results) > actual_limit,
                },
                # Deliberate: if a row holds a value json cannot encode, send
                # its string form instead of failing a query that succeeded.
                default=str,
            )

        except Exception as e:
            error_msg = str(e)
            lowered = error_msg.lower()

            # Provide helpful hints for common failure messages
            suggestions = []
            if "column" in lowered and "does not exist" in lowered:
                suggestions.append(
                    "Check column names using the schema introspection tool"
                )
            elif "table" in lowered and "does not exist" in lowered:
                suggestions.append(
                    "Check table names using the schema introspection tool"
                )
            elif "syntax error" in lowered:
                suggestions.append(
                    "Review SQL syntax, especially JOIN conditions and WHERE clauses"
                )

            return json.dumps({"error": error_msg, "suggestions": suggestions})

    async def process_tool_call(
        self, tool_name: str, tool_input: Dict[str, Any]
    ) -> str:
        """Dispatch a tool call by name and return its JSON string result.

        Unknown tool names and an ``execute_sql`` call without a ``query``
        produce a JSON ``error`` object instead of raising, so the model can
        see the problem and retry.
        """
        if tool_name == "list_tables":
            return await self.list_tables()
        if tool_name == "introspect_schema":
            return await self.introspect_schema(tool_input.get("table_pattern"))
        if tool_name == "execute_sql":
            query = tool_input.get("query")
            if not query:
                # The streamed tool input may be empty if its JSON never
                # became parseable; report that rather than raise KeyError.
                return json.dumps({"error": "Missing required parameter: query"})
            return await self.execute_sql(query, tool_input.get("limit", 100))
        return json.dumps({"error": f"Unknown tool: {tool_name}"})

    async def _process_stream_events(
        self, stream, content_blocks: List[Dict], tool_use_blocks: List[Dict]
    ) -> AsyncIterator[StreamEvent]:
        """Consume raw stream events, yielding StreamEvents as they arrive.

        ``content_blocks`` and ``tool_use_blocks`` are filled in place: text
        deltas are appended to the latest text block, and tool input JSON
        fragments are accumulated on the latest tool block under the
        temporary ``_partial`` key.
        """
        async for event in stream:
            if event.type == "content_block_start":
                started = event.content_block
                if not hasattr(started, "type"):
                    continue
                if started.type == "tool_use":
                    yield StreamEvent(
                        "tool_use",
                        {"name": started.name, "status": "started"},
                    )
                    tool_use_blocks.append(
                        {
                            "id": started.id,
                            "name": started.name,
                            "input": {},
                        }
                    )
                elif started.type == "text":
                    content_blocks.append({"type": "text", "text": ""})

            elif event.type == "content_block_delta":
                delta = event.delta
                if hasattr(delta, "text"):
                    yield StreamEvent("text", delta.text)
                    if content_blocks and content_blocks[-1]["type"] == "text":
                        content_blocks[-1]["text"] += delta.text
                elif hasattr(delta, "partial_json") and tool_use_blocks:
                    current = tool_use_blocks[-1]
                    current["_partial"] = (
                        current.get("_partial", "") + delta.partial_json
                    )
                    try:
                        current["input"] = json.loads(current["_partial"])
                    except json.JSONDecodeError:
                        # Expected while the JSON text is still incomplete;
                        # a later fragment will make it parseable.
                        pass

            elif event.type == "message_stop":
                break

    def _finalize_tool_blocks(self, tool_use_blocks: List[Dict]) -> str:
        """Turn accumulated tool blocks into final form; return stop reason.

        Each block is tagged with ``type: "tool_use"`` and loses its
        temporary ``_partial`` key. Returns ``"tool_use"`` when at least one
        tool block exists, otherwise ``"stop"``.
        """
        if not tool_use_blocks:
            return "stop"

        for block in tool_use_blocks:
            block["type"] = "tool_use"
            block.pop("_partial", None)
        return "tool_use"

    async def _process_tool_results(
        self, response: StreamingResponse
    ) -> AsyncIterator[StreamEvent]:
        """Run every tool_use block of ``response`` and yield progress events.

        The final event has type ``tool_result_data`` and carries the list of
        tool result blocks to send back to the model.
        """
        tool_results = []
        for block in response.content:
            if block.get("type") != "tool_use":
                continue

            yield StreamEvent(
                "tool_use",
                {
                    "name": block["name"],
                    "input": block["input"],
                    "status": "executing",
                },
            )

            tool_result = await self.process_tool_call(block["name"], block["input"])

            # Yield specific events based on tool type
            if block["name"] == "execute_sql" and self._last_results:
                yield StreamEvent(
                    "query_result",
                    {
                        "query": self._last_query,
                        "results": self._last_results,
                    },
                )
            elif block["name"] in ("list_tables", "introspect_schema"):
                yield StreamEvent(
                    "tool_result",
                    {
                        "tool_name": block["name"],
                        "result": tool_result,
                    },
                )

            tool_results.append(build_tool_result_block(block["id"], tool_result))

        yield StreamEvent("tool_result_data", tool_results)

    async def query_stream(
        self, user_query: str, use_history: bool = True
    ) -> AsyncIterator[StreamEvent]:
        """Process a user query and stream the resulting events.

        Args:
            user_query: The user's natural-language request.
            use_history: When true, prior conversation is sent along and the
                new exchange is appended to ``conversation_history``.

        Yields:
            StreamEvents for text, tool activity and results. Any exception
            raised while processing is reported as a single ``error`` event;
            history is not updated in that case.
        """
        # Reset per-query tracking state
        self._last_results = None
        self._last_query = None

        # Build messages with history if requested
        if use_history:
            messages = self.conversation_history + [
                {"role": "user", "content": user_query}
            ]
        else:
            messages = [{"role": "user", "content": user_query}]

        try:
            # Create initial stream; the internal "response_ready" event
            # carries the assembled response and is not forwarded.
            response = None
            async for event in self._create_and_process_stream(messages):
                if event.type == "response_ready":
                    response = event.data
                else:
                    yield event

            # Messages produced during this turn's tool loop
            collected_content = []

            # Process tool calls if needed
            while response is not None and response.stop_reason == "tool_use":
                # Add assistant's response to conversation
                collected_content.append(
                    {"role": "assistant", "content": response.content}
                )

                # Run the tools; "tool_result_data" is internal as well
                tool_results = []
                async for event in self._process_tool_results(response):
                    if event.type == "tool_result_data":
                        tool_results = event.data
                    else:
                        yield event

                # Continue conversation with tool results
                collected_content.append({"role": "user", "content": tool_results})

                # Signal that we're processing the tool results
                yield StreamEvent("processing", "Analyzing results...")

                # Get next response
                response = None
                async for event in self._create_and_process_stream(
                    messages + collected_content
                ):
                    if event.type == "response_ready":
                        response = event.data
                    else:
                        yield event

            # Update conversation history if using history
            if use_history:
                self.conversation_history.append(
                    {"role": "user", "content": user_query}
                )
                self.conversation_history.extend(collected_content)
                # Add final assistant response
                if response is not None:
                    self.conversation_history.append(
                        {"role": "assistant", "content": response.content}
                    )

        except Exception as e:
            yield StreamEvent("error", str(e))

    async def _create_and_process_stream(
        self, messages: List[Dict]
    ) -> AsyncIterator[StreamEvent]:
        """Open a streaming request and yield events while building a response.

        After the stream is consumed, a final ``response_ready`` event is
        yielded whose data is a ``StreamingResponse`` holding the text blocks
        followed by the tool_use blocks.
        """
        stream = await self.client.messages.create(
            model=self.model,
            max_tokens=4096,
            system=self.system_prompt,
            messages=messages,
            tools=self.tools,
            stream=True,
        )

        content_blocks = []
        tool_use_blocks = []

        async for event in self._process_stream_events(
            stream, content_blocks, tool_use_blocks
        ):
            yield event

        # Finalize tool blocks and create response
        stop_reason = self._finalize_tool_blocks(tool_use_blocks)
        content_blocks.extend(tool_use_blocks)

        yield StreamEvent(
            "response_ready", StreamingResponse(content_blocks, stop_reason)
        )
sqlsaber/agents/base.py ADDED
@@ -0,0 +1,67 @@
1
+ """Abstract base class for SQL agents."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, AsyncIterator, Dict, List, Optional
5
+
6
+ from sqlsaber.database.connection import BaseDatabaseConnection
7
+ from sqlsaber.models.events import StreamEvent
8
+
9
+
10
class BaseSQLAgent(ABC):
    """Abstract base class for SQL agents.

    Holds the database connection and the conversation history, and provides
    the SQL text checks that concrete agents apply before running a query.
    """

    def __init__(self, db_connection: BaseDatabaseConnection):
        """Store the connection and start with an empty conversation."""
        self.db = db_connection
        self.conversation_history: List[Dict[str, Any]] = []

    @abstractmethod
    async def query_stream(
        self, user_query: str, use_history: bool = True
    ) -> AsyncIterator[StreamEvent]:
        """Process a user query and stream responses."""
        pass

    def clear_history(self):
        """Clear conversation history."""
        self.conversation_history = []

    @abstractmethod
    async def process_tool_call(
        self, tool_name: str, tool_input: Dict[str, Any]
    ) -> str:
        """Process a tool call and return the result."""
        pass

    def _validate_write_operation(self, query: str) -> Optional[str]:
        """Validate if a write operation is allowed.

        The check looks only at the first keyword of the statement, after
        skipping leading whitespace and SQL comments (``-- ...`` and
        ``/* ... */``). It is a heuristic, not a SQL parser: writes hidden
        later in the text (for example after a ``;``) are not detected.

        Returns:
            None if operation is allowed, error message if not allowed.
        """
        import re

        # Skip leading whitespace/comments so "/* x */ DROP ..." cannot slip
        # past the first-keyword check.
        statement = re.sub(
            r"\A(?:\s+|--[^\n]*|/\*.*?\*/)+", "", query, flags=re.DOTALL
        )
        query_upper = statement.upper()

        # Statements that modify data or schema
        write_keywords = (
            "INSERT",
            "UPDATE",
            "DELETE",
            "DROP",
            "CREATE",
            "ALTER",
            "TRUNCATE",
            "REPLACE",
            "MERGE",
        )

        if query_upper.startswith(write_keywords):
            return (
                "Write operations are not allowed. Only SELECT queries are permitted."
            )

        return None

    def _add_limit_to_query(self, query: str, limit: int = 100) -> str:
        """Add LIMIT clause to SELECT queries if not present.

        The query is returned unchanged when ``limit`` is None, when it does
        not start with SELECT, or when it already contains the keyword LIMIT
        as a whole word (anywhere in the text, including subqueries).
        """
        import re

        if limit is None:
            # Nothing sensible to append; avoids producing "LIMIT None".
            return query

        query_upper = query.strip().upper()
        # Whole-word match so identifiers such as "credit_limit" do not
        # suppress the clause.
        has_limit = re.search(r"\bLIMIT\b", query_upper) is not None
        if query_upper.startswith("SELECT") and not has_limit:
            # Drop trailing semicolons together with any surrounding
            # whitespace so the clause lands inside the statement.
            body = query.rstrip("; \t\r\n\f\v")
            return f"{body} LIMIT {limit};"
        return query
sqlsaber/agents/streaming.py ADDED
@@ -0,0 +1,26 @@
1
+ """Streaming utilities for agents."""
2
+
3
+ from typing import Any, Dict, List
4
+
5
+
6
class StreamingResponse:
    """Pairs the content blocks of an assembled response with its stop reason."""

    def __init__(self, content: List[Dict[str, Any]], stop_reason: str):
        # Both values are stored exactly as given (no copy, no validation).
        self.content, self.stop_reason = content, stop_reason
12
+
13
+
14
def build_tool_result_block(tool_use_id: str, content: str) -> Dict[str, Any]:
    """Return a ``tool_result`` block answering the tool call ``tool_use_id``.

    The returned dict has the keys ``type``, ``tool_use_id`` and ``content``,
    in that order.
    """
    result_block: Dict[str, Any] = {"type": "tool_result"}
    result_block["tool_use_id"] = tool_use_id
    result_block["content"] = content
    return result_block
17
+
18
+
19
def extract_sql_from_text(text: str) -> str:
    """Return the contents of the first ```sql fenced block in ``text``.

    The result is stripped of surrounding whitespace. An empty string is
    returned when there is no ```sql fence, when the fence is never closed,
    or when nothing sits between the opening and closing fences.
    """
    _, opening, remainder = text.partition("```sql")
    if not opening:
        return ""

    fenced, closing, _ = remainder.partition("```")
    if not closing or not fenced:
        return ""
    return fenced.strip()
sqlsaber/cli/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """CLI module for SQLSaber."""
2
+
3
+ from .commands import main
4
+
5
# Names re-exported by the cli package.
__all__ = ["main"]