sqlsaber 0.2.0__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber-0.3.0/CHANGELOG.md +66 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/PKG-INFO +2 -1
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/pyproject.toml +2 -1
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/agents/anthropic.py +19 -113
- sqlsaber-0.3.0/src/sqlsaber/agents/base.py +184 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/agents/streaming.py +0 -10
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/cli/commands.py +28 -10
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/cli/database.py +1 -1
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/config/database.py +25 -3
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/database/connection.py +129 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/database/schema.py +90 -66
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/uv.lock +117 -1
- sqlsaber-0.2.0/src/sqlsaber/agents/base.py +0 -67
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/.github/workflows/publish.yml +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/.gitignore +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/.python-version +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/CLAUDE.md +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/LICENSE +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/README.md +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/pytest.ini +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/__init__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/__main__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/agents/__init__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/cli/__init__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/cli/display.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/cli/interactive.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/cli/memory.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/cli/models.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/cli/streaming.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/config/__init__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/config/api_keys.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/config/settings.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/database/__init__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/memory/__init__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/memory/manager.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/memory/storage.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/models/__init__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/models/events.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/src/sqlsaber/models/types.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/tests/__init__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/tests/conftest.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/tests/test_cli/__init__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/tests/test_cli/test_commands.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/tests/test_config/__init__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/tests/test_config/test_database.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/tests/test_config/test_settings.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/tests/test_database/__init__.py +0 -0
- {sqlsaber-0.2.0 → sqlsaber-0.3.0}/tests/test_database/test_connection.py +0 -0
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
All notable changes to SQLSaber will be documented in this file.
|
|
4
|
+
|
|
5
|
+
## [Unreleased]
|
|
6
|
+
|
|
7
|
+
## [0.3.0] - 2025-06-25
|
|
8
|
+
|
|
9
|
+
### Added
|
|
10
|
+
|
|
11
|
+
- Support for CSV files as a database option: `saber query -d mydata.csv`
|
|
12
|
+
|
|
13
|
+
### Changed
|
|
14
|
+
|
|
15
|
+
- Extracted tools to BaseSQLAgent for better inheritance across SQLAgents
|
|
16
|
+
|
|
17
|
+
### Fixed
|
|
18
|
+
|
|
19
|
+
- Fixed getting row counts for SQLite
|
|
20
|
+
|
|
21
|
+
## [0.2.0] - 2025-06-24
|
|
22
|
+
|
|
23
|
+
### Added
|
|
24
|
+
|
|
25
|
+
- SSL support for database connections during configuration
|
|
26
|
+
- Memory feature similar to Claude Code
|
|
27
|
+
- Support for SQLite and MySQL databases
|
|
28
|
+
- Model configuration (configure, select, set, reset) - Anthropic models only
|
|
29
|
+
- Comprehensive database command to securely store multiple database connection info
|
|
30
|
+
- API key storage using keyring for security
|
|
31
|
+
- Interactive questionary for all user interactions
|
|
32
|
+
- Test suite implementation
|
|
33
|
+
|
|
34
|
+
### Changed
|
|
35
|
+
|
|
36
|
+
- Package renamed from original name to sqlsaber
|
|
37
|
+
- Better configuration handling
|
|
38
|
+
- Simplified CLI interface
|
|
39
|
+
- Refactored query stream function into smaller functions
|
|
40
|
+
- Interactive markup cleanup
|
|
41
|
+
- Extracted table display functionality
|
|
42
|
+
- Refactored and cleaned up codebase structure
|
|
43
|
+
|
|
44
|
+
### Fixed
|
|
45
|
+
|
|
46
|
+
- Fixed list_tables tool functionality
|
|
47
|
+
- Fixed introspect schema tool
|
|
48
|
+
- Fixed minor type checking errors
|
|
49
|
+
- Check before adding new database to prevent duplicates
|
|
50
|
+
|
|
51
|
+
### Removed
|
|
52
|
+
|
|
53
|
+
- Removed write support completely for security
|
|
54
|
+
|
|
55
|
+
## [0.1.0] - 2025-06-19
|
|
56
|
+
|
|
57
|
+
### Added
|
|
58
|
+
|
|
59
|
+
- First working version of SQLSaber
|
|
60
|
+
- Streaming tool response and status messages
|
|
61
|
+
- Schema introspection with table listing
|
|
62
|
+
- Result row streaming as agent works
|
|
63
|
+
- Database connection and query capabilities
|
|
64
|
+
- Added publish workflow
|
|
65
|
+
- Created documentation and README
|
|
66
|
+
- Added CLAUDE.md for development instructions
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sqlsaber
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: SQLSaber - Agentic SQL assistant like Claude Code
|
|
5
5
|
License-File: LICENSE
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -10,6 +10,7 @@ Requires-Dist: anthropic>=0.54.0
|
|
|
10
10
|
Requires-Dist: asyncpg>=0.30.0
|
|
11
11
|
Requires-Dist: httpx>=0.28.1
|
|
12
12
|
Requires-Dist: keyring>=25.6.0
|
|
13
|
+
Requires-Dist: pandas>=2.0.0
|
|
13
14
|
Requires-Dist: platformdirs>=4.0.0
|
|
14
15
|
Requires-Dist: questionary>=2.1.0
|
|
15
16
|
Requires-Dist: rich>=13.7.0
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "sqlsaber"
|
|
3
|
-
version = "0.2.0"
|
|
3
|
+
version = "0.3.0"
|
|
4
4
|
description = "SQLSaber - Agentic SQL assistant like Claude Code"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.12"
|
|
@@ -15,6 +15,7 @@ dependencies = [
|
|
|
15
15
|
"httpx>=0.28.1",
|
|
16
16
|
"aiomysql>=0.2.0",
|
|
17
17
|
"aiosqlite>=0.21.0",
|
|
18
|
+
"pandas>=2.0.0",
|
|
18
19
|
]
|
|
19
20
|
|
|
20
21
|
[tool.uv]
|
|
@@ -11,13 +11,7 @@ from sqlsaber.agents.streaming import (
|
|
|
11
11
|
build_tool_result_block,
|
|
12
12
|
)
|
|
13
13
|
from sqlsaber.config.settings import Config
|
|
14
|
-
from sqlsaber.database.connection import (
|
|
15
|
-
BaseDatabaseConnection,
|
|
16
|
-
MySQLConnection,
|
|
17
|
-
PostgreSQLConnection,
|
|
18
|
-
SQLiteConnection,
|
|
19
|
-
)
|
|
20
|
-
from sqlsaber.database.schema import SchemaManager
|
|
14
|
+
from sqlsaber.database.connection import BaseDatabaseConnection
|
|
21
15
|
from sqlsaber.memory.manager import MemoryManager
|
|
22
16
|
from sqlsaber.models.events import StreamEvent
|
|
23
17
|
from sqlsaber.models.types import ToolDefinition
|
|
@@ -36,7 +30,6 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
36
30
|
|
|
37
31
|
self.client = AsyncAnthropic(api_key=config.api_key)
|
|
38
32
|
self.model = config.model_name.replace("anthropic:", "")
|
|
39
|
-
self.schema_manager = SchemaManager(db_connection)
|
|
40
33
|
|
|
41
34
|
self.database_name = database_name
|
|
42
35
|
self.memory_manager = MemoryManager()
|
|
@@ -94,17 +87,6 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
94
87
|
# Build system prompt with memories if available
|
|
95
88
|
self.system_prompt = self._build_system_prompt()
|
|
96
89
|
|
|
97
|
-
def _get_database_type_name(self) -> str:
|
|
98
|
-
"""Get the human-readable database type name."""
|
|
99
|
-
if isinstance(self.db, PostgreSQLConnection):
|
|
100
|
-
return "PostgreSQL"
|
|
101
|
-
elif isinstance(self.db, MySQLConnection):
|
|
102
|
-
return "MySQL"
|
|
103
|
-
elif isinstance(self.db, SQLiteConnection):
|
|
104
|
-
return "SQLite"
|
|
105
|
-
else:
|
|
106
|
-
return "database" # Fallback
|
|
107
|
-
|
|
108
90
|
def _build_system_prompt(self) -> str:
|
|
109
91
|
"""Build system prompt with optional memory context."""
|
|
110
92
|
db_type = self._get_database_type_name()
|
|
@@ -152,109 +134,33 @@ Guidelines:
|
|
|
152
134
|
self.system_prompt = self._build_system_prompt()
|
|
153
135
|
return memory.id
|
|
154
136
|
|
|
155
|
-
async def introspect_schema(self, table_pattern: Optional[str] = None) -> str:
|
|
156
|
-
"""Introspect database schema to understand table structures."""
|
|
157
|
-
try:
|
|
158
|
-
# Pass table_pattern to get_schema_info for efficient filtering at DB level
|
|
159
|
-
schema_info = await self.schema_manager.get_schema_info(table_pattern)
|
|
160
|
-
|
|
161
|
-
# Format the schema information
|
|
162
|
-
formatted_info = {}
|
|
163
|
-
for table_name, table_info in schema_info.items():
|
|
164
|
-
formatted_info[table_name] = {
|
|
165
|
-
"columns": {
|
|
166
|
-
col_name: {
|
|
167
|
-
"type": col_info["data_type"],
|
|
168
|
-
"nullable": col_info["nullable"],
|
|
169
|
-
"default": col_info["default"],
|
|
170
|
-
}
|
|
171
|
-
for col_name, col_info in table_info["columns"].items()
|
|
172
|
-
},
|
|
173
|
-
"primary_keys": table_info["primary_keys"],
|
|
174
|
-
"foreign_keys": [
|
|
175
|
-
f"{fk['column']} -> {fk['references']['table']}.{fk['references']['column']}"
|
|
176
|
-
for fk in table_info["foreign_keys"]
|
|
177
|
-
],
|
|
178
|
-
}
|
|
179
|
-
|
|
180
|
-
return json.dumps(formatted_info)
|
|
181
|
-
except Exception as e:
|
|
182
|
-
return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
|
|
183
|
-
|
|
184
|
-
async def list_tables(self) -> str:
|
|
185
|
-
"""List all tables in the database with basic information."""
|
|
186
|
-
try:
|
|
187
|
-
tables_info = await self.schema_manager.list_tables()
|
|
188
|
-
return json.dumps(tables_info)
|
|
189
|
-
except Exception as e:
|
|
190
|
-
return json.dumps({"error": f"Error listing tables: {str(e)}"})
|
|
191
|
-
|
|
192
137
|
async def execute_sql(self, query: str, limit: Optional[int] = 100) -> str:
|
|
193
|
-
"""Execute a SQL query against the database."""
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
write_error = self._validate_write_operation(query)
|
|
197
|
-
if write_error:
|
|
198
|
-
return json.dumps(
|
|
199
|
-
{
|
|
200
|
-
"error": write_error,
|
|
201
|
-
}
|
|
202
|
-
)
|
|
203
|
-
|
|
204
|
-
# Add LIMIT if not present and it's a SELECT query
|
|
205
|
-
query = self._add_limit_to_query(query, limit)
|
|
206
|
-
|
|
207
|
-
# Execute the query (wrapped in a transaction for safety)
|
|
208
|
-
results = await self.db.execute_query(query)
|
|
209
|
-
|
|
210
|
-
# Format results - but also store the actual data
|
|
211
|
-
actual_limit = limit if limit is not None else len(results)
|
|
212
|
-
self._last_results = results[:actual_limit]
|
|
213
|
-
self._last_query = query
|
|
138
|
+
"""Execute a SQL query against the database with streaming support."""
|
|
139
|
+
# Call parent implementation for core functionality
|
|
140
|
+
result = await super().execute_sql(query, limit)
|
|
214
141
|
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
)
|
|
223
|
-
|
|
224
|
-
except Exception as e:
|
|
225
|
-
error_msg = str(e)
|
|
226
|
-
|
|
227
|
-
# Provide helpful error messages
|
|
228
|
-
suggestions = []
|
|
229
|
-
if "column" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
230
|
-
suggestions.append(
|
|
231
|
-
"Check column names using the schema introspection tool"
|
|
232
|
-
)
|
|
233
|
-
elif "table" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
234
|
-
suggestions.append(
|
|
235
|
-
"Check table names using the schema introspection tool"
|
|
236
|
-
)
|
|
237
|
-
elif "syntax error" in error_msg.lower():
|
|
238
|
-
suggestions.append(
|
|
239
|
-
"Review SQL syntax, especially JOIN conditions and WHERE clauses"
|
|
142
|
+
# Parse result to extract data for streaming (AnthropicSQLAgent specific)
|
|
143
|
+
try:
|
|
144
|
+
result_data = json.loads(result)
|
|
145
|
+
if result_data.get("success") and "results" in result_data:
|
|
146
|
+
# Store results for streaming
|
|
147
|
+
actual_limit = (
|
|
148
|
+
limit if limit is not None else len(result_data["results"])
|
|
240
149
|
)
|
|
150
|
+
self._last_results = result_data["results"][:actual_limit]
|
|
151
|
+
self._last_query = query
|
|
152
|
+
except (json.JSONDecodeError, KeyError):
|
|
153
|
+
# If we can't parse the result, just continue without storing
|
|
154
|
+
pass
|
|
241
155
|
|
|
242
|
-
|
|
156
|
+
return result
|
|
243
157
|
|
|
244
158
|
async def process_tool_call(
|
|
245
159
|
self, tool_name: str, tool_input: Dict[str, Any]
|
|
246
160
|
) -> str:
|
|
247
161
|
"""Process a tool call and return the result."""
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
elif tool_name == "introspect_schema":
|
|
251
|
-
return await self.introspect_schema(tool_input.get("table_pattern"))
|
|
252
|
-
elif tool_name == "execute_sql":
|
|
253
|
-
return await self.execute_sql(
|
|
254
|
-
tool_input["query"], tool_input.get("limit", 100)
|
|
255
|
-
)
|
|
256
|
-
else:
|
|
257
|
-
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
162
|
+
# Use parent implementation for core tools
|
|
163
|
+
return await super().process_tool_call(tool_name, tool_input)
|
|
258
164
|
|
|
259
165
|
async def _process_stream_events(
|
|
260
166
|
self, stream, content_blocks: List[Dict], tool_use_blocks: List[Dict]
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""Abstract base class for SQL agents."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
from sqlsaber.database.connection import (
|
|
8
|
+
BaseDatabaseConnection,
|
|
9
|
+
CSVConnection,
|
|
10
|
+
MySQLConnection,
|
|
11
|
+
PostgreSQLConnection,
|
|
12
|
+
SQLiteConnection,
|
|
13
|
+
)
|
|
14
|
+
from sqlsaber.database.schema import SchemaManager
|
|
15
|
+
from sqlsaber.models.events import StreamEvent
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseSQLAgent(ABC):
|
|
19
|
+
"""Abstract base class for SQL agents."""
|
|
20
|
+
|
|
21
|
+
def __init__(self, db_connection: BaseDatabaseConnection):
|
|
22
|
+
self.db = db_connection
|
|
23
|
+
self.schema_manager = SchemaManager(db_connection)
|
|
24
|
+
self.conversation_history: List[Dict[str, Any]] = []
|
|
25
|
+
|
|
26
|
+
@abstractmethod
|
|
27
|
+
async def query_stream(
|
|
28
|
+
self, user_query: str, use_history: bool = True
|
|
29
|
+
) -> AsyncIterator[StreamEvent]:
|
|
30
|
+
"""Process a user query and stream responses."""
|
|
31
|
+
pass
|
|
32
|
+
|
|
33
|
+
def clear_history(self):
|
|
34
|
+
"""Clear conversation history."""
|
|
35
|
+
self.conversation_history = []
|
|
36
|
+
|
|
37
|
+
def _get_database_type_name(self) -> str:
|
|
38
|
+
"""Get the human-readable database type name."""
|
|
39
|
+
if isinstance(self.db, PostgreSQLConnection):
|
|
40
|
+
return "PostgreSQL"
|
|
41
|
+
elif isinstance(self.db, MySQLConnection):
|
|
42
|
+
return "MySQL"
|
|
43
|
+
elif isinstance(self.db, SQLiteConnection):
|
|
44
|
+
return "SQLite"
|
|
45
|
+
elif isinstance(self.db, CSVConnection):
|
|
46
|
+
return "SQLite" # we convert csv to in-memory sqlite
|
|
47
|
+
else:
|
|
48
|
+
return "database" # Fallback
|
|
49
|
+
|
|
50
|
+
async def introspect_schema(self, table_pattern: Optional[str] = None) -> str:
|
|
51
|
+
"""Introspect database schema to understand table structures."""
|
|
52
|
+
try:
|
|
53
|
+
# Pass table_pattern to get_schema_info for efficient filtering at DB level
|
|
54
|
+
schema_info = await self.schema_manager.get_schema_info(table_pattern)
|
|
55
|
+
|
|
56
|
+
# Format the schema information
|
|
57
|
+
formatted_info = {}
|
|
58
|
+
for table_name, table_info in schema_info.items():
|
|
59
|
+
formatted_info[table_name] = {
|
|
60
|
+
"columns": {
|
|
61
|
+
col_name: {
|
|
62
|
+
"type": col_info["data_type"],
|
|
63
|
+
"nullable": col_info["nullable"],
|
|
64
|
+
"default": col_info["default"],
|
|
65
|
+
}
|
|
66
|
+
for col_name, col_info in table_info["columns"].items()
|
|
67
|
+
},
|
|
68
|
+
"primary_keys": table_info["primary_keys"],
|
|
69
|
+
"foreign_keys": [
|
|
70
|
+
f"{fk['column']} -> {fk['references']['table']}.{fk['references']['column']}"
|
|
71
|
+
for fk in table_info["foreign_keys"]
|
|
72
|
+
],
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
return json.dumps(formatted_info)
|
|
76
|
+
except Exception as e:
|
|
77
|
+
return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
|
|
78
|
+
|
|
79
|
+
async def list_tables(self) -> str:
|
|
80
|
+
"""List all tables in the database with basic information."""
|
|
81
|
+
try:
|
|
82
|
+
tables_info = await self.schema_manager.list_tables()
|
|
83
|
+
return json.dumps(tables_info)
|
|
84
|
+
except Exception as e:
|
|
85
|
+
return json.dumps({"error": f"Error listing tables: {str(e)}"})
|
|
86
|
+
|
|
87
|
+
async def execute_sql(self, query: str, limit: Optional[int] = 100) -> str:
|
|
88
|
+
"""Execute a SQL query against the database."""
|
|
89
|
+
try:
|
|
90
|
+
# Security check - only allow SELECT queries unless write is enabled
|
|
91
|
+
write_error = self._validate_write_operation(query)
|
|
92
|
+
if write_error:
|
|
93
|
+
return json.dumps(
|
|
94
|
+
{
|
|
95
|
+
"error": write_error,
|
|
96
|
+
}
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
# Add LIMIT if not present and it's a SELECT query
|
|
100
|
+
query = self._add_limit_to_query(query, limit)
|
|
101
|
+
|
|
102
|
+
# Execute the query (wrapped in a transaction for safety)
|
|
103
|
+
results = await self.db.execute_query(query)
|
|
104
|
+
|
|
105
|
+
# Format results
|
|
106
|
+
actual_limit = limit if limit is not None else len(results)
|
|
107
|
+
|
|
108
|
+
return json.dumps(
|
|
109
|
+
{
|
|
110
|
+
"success": True,
|
|
111
|
+
"row_count": len(results),
|
|
112
|
+
"results": results[:actual_limit], # Extra safety for limit
|
|
113
|
+
"truncated": len(results) > actual_limit,
|
|
114
|
+
}
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
except Exception as e:
|
|
118
|
+
error_msg = str(e)
|
|
119
|
+
|
|
120
|
+
# Provide helpful error messages
|
|
121
|
+
suggestions = []
|
|
122
|
+
if "column" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
123
|
+
suggestions.append(
|
|
124
|
+
"Check column names using the schema introspection tool"
|
|
125
|
+
)
|
|
126
|
+
elif "table" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
127
|
+
suggestions.append(
|
|
128
|
+
"Check table names using the schema introspection tool"
|
|
129
|
+
)
|
|
130
|
+
elif "syntax error" in error_msg.lower():
|
|
131
|
+
suggestions.append(
|
|
132
|
+
"Review SQL syntax, especially JOIN conditions and WHERE clauses"
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
return json.dumps({"error": error_msg, "suggestions": suggestions})
|
|
136
|
+
|
|
137
|
+
async def process_tool_call(
|
|
138
|
+
self, tool_name: str, tool_input: Dict[str, Any]
|
|
139
|
+
) -> str:
|
|
140
|
+
"""Process a tool call and return the result."""
|
|
141
|
+
if tool_name == "list_tables":
|
|
142
|
+
return await self.list_tables()
|
|
143
|
+
elif tool_name == "introspect_schema":
|
|
144
|
+
return await self.introspect_schema(tool_input.get("table_pattern"))
|
|
145
|
+
elif tool_name == "execute_sql":
|
|
146
|
+
return await self.execute_sql(
|
|
147
|
+
tool_input["query"], tool_input.get("limit", 100)
|
|
148
|
+
)
|
|
149
|
+
else:
|
|
150
|
+
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
151
|
+
|
|
152
|
+
def _validate_write_operation(self, query: str) -> Optional[str]:
|
|
153
|
+
"""Validate if a write operation is allowed.
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
None if operation is allowed, error message if not allowed.
|
|
157
|
+
"""
|
|
158
|
+
query_upper = query.strip().upper()
|
|
159
|
+
|
|
160
|
+
# Check for write operations
|
|
161
|
+
write_keywords = [
|
|
162
|
+
"INSERT",
|
|
163
|
+
"UPDATE",
|
|
164
|
+
"DELETE",
|
|
165
|
+
"DROP",
|
|
166
|
+
"CREATE",
|
|
167
|
+
"ALTER",
|
|
168
|
+
"TRUNCATE",
|
|
169
|
+
]
|
|
170
|
+
is_write_query = any(query_upper.startswith(kw) for kw in write_keywords)
|
|
171
|
+
|
|
172
|
+
if is_write_query:
|
|
173
|
+
return (
|
|
174
|
+
"Write operations are not allowed. Only SELECT queries are permitted."
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
return None
|
|
178
|
+
|
|
179
|
+
def _add_limit_to_query(self, query: str, limit: int = 100) -> str:
|
|
180
|
+
"""Add LIMIT clause to SELECT queries if not present."""
|
|
181
|
+
query_upper = query.strip().upper()
|
|
182
|
+
if query_upper.startswith("SELECT") and "LIMIT" not in query_upper:
|
|
183
|
+
return f"{query.rstrip(';')} LIMIT {limit};"
|
|
184
|
+
return query
|
|
@@ -14,13 +14,3 @@ class StreamingResponse:
|
|
|
14
14
|
def build_tool_result_block(tool_use_id: str, content: str) -> Dict[str, Any]:
|
|
15
15
|
"""Build a tool result block for the conversation."""
|
|
16
16
|
return {"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def extract_sql_from_text(text: str) -> str:
|
|
20
|
-
"""Extract SQL query from markdown-formatted text."""
|
|
21
|
-
if "```sql" in text:
|
|
22
|
-
sql_start = text.find("```sql") + 6
|
|
23
|
-
sql_end = text.find("```", sql_start)
|
|
24
|
-
if sql_end > sql_start:
|
|
25
|
-
return text[sql_start:sql_end].strip()
|
|
26
|
-
return ""
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""CLI command definitions and handlers."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
from pathlib import Path
|
|
4
5
|
from typing import Optional
|
|
5
6
|
|
|
6
7
|
import typer
|
|
@@ -62,15 +63,31 @@ def query(
|
|
|
62
63
|
"""Run a query against the database or start interactive mode."""
|
|
63
64
|
|
|
64
65
|
async def run_session():
|
|
65
|
-
# Get database configuration
|
|
66
|
+
# Get database configuration or handle direct CSV file
|
|
66
67
|
if database:
|
|
67
|
-
|
|
68
|
-
if
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
68
|
+
# Check if this is a direct CSV file path
|
|
69
|
+
if database.endswith(".csv"):
|
|
70
|
+
csv_path = Path(database).expanduser().resolve()
|
|
71
|
+
if not csv_path.exists():
|
|
72
|
+
console.print(
|
|
73
|
+
f"[bold red]Error:[/bold red] CSV file '{database}' not found."
|
|
74
|
+
)
|
|
75
|
+
raise typer.Exit(1)
|
|
76
|
+
connection_string = f"csv:///{csv_path}"
|
|
77
|
+
db_name = csv_path.stem
|
|
78
|
+
else:
|
|
79
|
+
# Look up configured database connection
|
|
80
|
+
db_config = config_manager.get_database(database)
|
|
81
|
+
if not db_config:
|
|
82
|
+
console.print(
|
|
83
|
+
f"[bold red]Error:[/bold red] Database connection '{database}' not found."
|
|
84
|
+
)
|
|
85
|
+
console.print(
|
|
86
|
+
"Use 'sqlsaber db list' to see available connections."
|
|
87
|
+
)
|
|
88
|
+
raise typer.Exit(1)
|
|
89
|
+
connection_string = db_config.to_connection_string()
|
|
90
|
+
db_name = db_config.name
|
|
74
91
|
else:
|
|
75
92
|
db_config = config_manager.get_default_database()
|
|
76
93
|
if not db_config:
|
|
@@ -81,10 +98,11 @@ def query(
|
|
|
81
98
|
"Use 'sqlsaber db add <name>' to add a database connection."
|
|
82
99
|
)
|
|
83
100
|
raise typer.Exit(1)
|
|
101
|
+
connection_string = db_config.to_connection_string()
|
|
102
|
+
db_name = db_config.name
|
|
84
103
|
|
|
85
104
|
# Create database connection
|
|
86
105
|
try:
|
|
87
|
-
connection_string = db_config.to_connection_string()
|
|
88
106
|
db_conn = DatabaseConnection(connection_string)
|
|
89
107
|
except Exception as e:
|
|
90
108
|
console.print(
|
|
@@ -93,7 +111,7 @@ def query(
|
|
|
93
111
|
raise typer.Exit(1)
|
|
94
112
|
|
|
95
113
|
# Create agent instance with database name for memory context
|
|
96
|
-
agent = AnthropicSQLAgent(db_conn, db_config.name)
|
|
114
|
+
agent = AnthropicSQLAgent(db_conn, db_name)
|
|
97
115
|
|
|
98
116
|
try:
|
|
99
117
|
if query_text:
|
|
@@ -75,7 +75,7 @@ def add_database(
|
|
|
75
75
|
if type == "sqlite":
|
|
76
76
|
# SQLite only needs database path
|
|
77
77
|
database = database or questionary.path("Database file path:").ask()
|
|
78
|
-
database = str(Path(database).expanduser())
|
|
78
|
+
database = str(Path(database).expanduser().resolve())
|
|
79
79
|
host = "localhost"
|
|
80
80
|
port = 0
|
|
81
81
|
username = "sqlite"
|
|
@@ -4,12 +4,12 @@ import json
|
|
|
4
4
|
import os
|
|
5
5
|
import platform
|
|
6
6
|
import stat
|
|
7
|
-
import keyring
|
|
8
7
|
from dataclasses import dataclass
|
|
9
8
|
from pathlib import Path
|
|
10
|
-
from typing import Dict, List, Optional
|
|
9
|
+
from typing import Any, Dict, List, Optional
|
|
11
10
|
from urllib.parse import quote_plus
|
|
12
11
|
|
|
12
|
+
import keyring
|
|
13
13
|
import platformdirs
|
|
14
14
|
|
|
15
15
|
|
|
@@ -18,7 +18,7 @@ class DatabaseConfig:
|
|
|
18
18
|
"""Database connection configuration."""
|
|
19
19
|
|
|
20
20
|
name: str
|
|
21
|
-
type: str # postgresql, mysql, sqlite
|
|
21
|
+
type: str # postgresql, mysql, sqlite, csv
|
|
22
22
|
host: Optional[str]
|
|
23
23
|
port: Optional[int]
|
|
24
24
|
database: str
|
|
@@ -90,6 +90,28 @@ class DatabaseConfig:
|
|
|
90
90
|
|
|
91
91
|
elif self.type == "sqlite":
|
|
92
92
|
return f"sqlite:///{self.database}"
|
|
93
|
+
elif self.type == "csv":
|
|
94
|
+
# For CSV files, database field contains the file path
|
|
95
|
+
base_url = f"csv:///{self.database}"
|
|
96
|
+
|
|
97
|
+
# Add CSV-specific parameters if they exist in schema field
|
|
98
|
+
if self.schema:
|
|
99
|
+
# Schema field can contain CSV options in JSON format
|
|
100
|
+
try:
|
|
101
|
+
csv_options = json.loads(self.schema)
|
|
102
|
+
params = []
|
|
103
|
+
if "delimiter" in csv_options:
|
|
104
|
+
params.append(f"delimiter={csv_options['delimiter']}")
|
|
105
|
+
if "encoding" in csv_options:
|
|
106
|
+
params.append(f"encoding={csv_options['encoding']}")
|
|
107
|
+
if "header" in csv_options:
|
|
108
|
+
params.append(f"header={str(csv_options['header']).lower()}")
|
|
109
|
+
|
|
110
|
+
if params:
|
|
111
|
+
return f"{base_url}?{'&'.join(params)}"
|
|
112
|
+
except (json.JSONDecodeError, KeyError):
|
|
113
|
+
pass
|
|
114
|
+
return base_url
|
|
93
115
|
else:
|
|
94
116
|
raise ValueError(f"Unsupported database type: {self.type}")
|
|
95
117
|
|