sqlsaber 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlsaber/agents/__init__.py +2 -2
- sqlsaber/agents/base.py +1 -1
- sqlsaber/agents/mcp.py +1 -1
- sqlsaber/agents/pydantic_ai_agent.py +207 -135
- sqlsaber/cli/commands.py +11 -28
- sqlsaber/cli/completers.py +2 -0
- sqlsaber/cli/database.py +1 -1
- sqlsaber/cli/display.py +29 -9
- sqlsaber/cli/interactive.py +22 -15
- sqlsaber/cli/streaming.py +15 -17
- sqlsaber/cli/threads.py +10 -6
- sqlsaber/config/settings.py +25 -2
- sqlsaber/database/__init__.py +55 -1
- sqlsaber/database/base.py +124 -0
- sqlsaber/database/csv.py +133 -0
- sqlsaber/database/duckdb.py +313 -0
- sqlsaber/database/mysql.py +345 -0
- sqlsaber/database/postgresql.py +328 -0
- sqlsaber/database/schema.py +66 -963
- sqlsaber/database/sqlite.py +258 -0
- sqlsaber/mcp/mcp.py +1 -1
- sqlsaber/tools/sql_tools.py +1 -1
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/METADATA +43 -9
- sqlsaber-0.26.0.dist-info/RECORD +52 -0
- sqlsaber/database/connection.py +0 -535
- sqlsaber-0.25.0.dist-info/RECORD +0 -47
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/interactive.py
CHANGED
@@ -3,13 +3,13 @@
 import asyncio
 from pathlib import Path
 from textwrap import dedent
+from typing import TYPE_CHECKING
 
 import platformdirs
 from prompt_toolkit import PromptSession
 from prompt_toolkit.history import FileHistory
 from prompt_toolkit.patch_stdout import patch_stdout
 from prompt_toolkit.styles import Style
-from pydantic_ai import Agent
 from rich.console import Console
 from rich.markdown import Markdown
 from rich.panel import Panel
@@ -21,7 +21,7 @@ from sqlsaber.cli.completers import (
 )
 from sqlsaber.cli.display import DisplayManager
 from sqlsaber.cli.streaming import StreamingQueryHandler
-from sqlsaber.database.connection import (
+from sqlsaber.database import (
     CSVConnection,
     DuckDBConnection,
     MySQLConnection,
@@ -31,6 +31,9 @@ from sqlsaber.database.connection import (
 from sqlsaber.database.schema import SchemaManager
 from sqlsaber.threads import ThreadStorage
 
+if TYPE_CHECKING:
+    from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
+
 
 def bottom_toolbar():
     return [
@@ -55,7 +58,7 @@ class InteractiveSession:
     def __init__(
         self,
         console: Console,
-        agent: Agent,
+        sqlsaber_agent: "SQLSaberAgent",
         db_conn,
         database_name: str,
         *,
@@ -63,7 +66,7 @@ class InteractiveSession:
         initial_history: list | None = None,
     ):
         self.console = console
-        self.agent = agent
+        self.sqlsaber_agent = sqlsaber_agent
         self.db_conn = db_conn
         self.database_name = database_name
         self.display = DisplayManager(console)
@@ -176,7 +179,7 @@
             query_task = asyncio.create_task(
                 self.streaming_handler.execute_streaming_query(
                     user_query,
-                    self.agent,
+                    self.sqlsaber_agent,
                     self.cancellation_token,
                     self.message_history,
                 )
@@ -191,11 +194,6 @@
             # Use all_messages() so the system prompt and all prior turns are preserved
             self.message_history = run_result.all_messages()
 
-            # Extract title (first user prompt) and model name
-            if not self._thread_id:
-                title = user_query
-                model_name = self.agent.model.model_name
-
             # Persist snapshot to thread storage (create or overwrite)
             self._thread_id = await self._threads.save_snapshot(
                 messages_json=run_result.all_messages_json(),
@@ -206,8 +204,8 @@
             if self.first_message:
                 await self._threads.save_metadata(
                     thread_id=self._thread_id,
-                    title=title,
-                    model_name=model_name,
+                    title=user_query,
+                    model_name=self.sqlsaber_agent.agent.model.model_name,
                 )
         except Exception:
             pass
@@ -269,6 +267,17 @@
                 self._thread_id = None
                 continue
 
+            # Thinking commands
+            if user_query == "/thinking on":
+                self.sqlsaber_agent.set_thinking(enabled=True)
+                self.console.print("[green]✓ Thinking enabled[/green]\n")
+                continue
+
+            if user_query == "/thinking off":
+                self.sqlsaber_agent.set_thinking(enabled=False)
+                self.console.print("[green]✓ Thinking disabled[/green]\n")
+                continue
+
             if memory_text := user_query.strip():
                 # Check if query starts with # for memory addition
                 if memory_text.startswith("#"):
@@ -276,9 +285,7 @@
                     if memory_content:
                         # Add memory via the agent's memory manager
                         try:
-                            mm = getattr(
-                                self.agent, "_sqlsaber_memory_manager", None
-                            )
+                            mm = self.sqlsaber_agent.memory_manager
                             if mm and self.database_name:
                                 memory = mm.add_memory(
                                     self.database_name, memory_content
sqlsaber/cli/streaming.py
CHANGED
@@ -8,9 +8,9 @@ rendered via DisplayManager helpers.
 import asyncio
 import json
 from functools import singledispatchmethod
-from typing import AsyncIterable
+from typing import TYPE_CHECKING, AsyncIterable
 
-from pydantic_ai import Agent, RunContext
+from pydantic_ai import RunContext
 from pydantic_ai.messages import (
     AgentStreamEvent,
     FunctionToolCallEvent,
@@ -26,6 +26,9 @@ from rich.console import Console
 
 from sqlsaber.cli.display import DisplayManager
 
+if TYPE_CHECKING:
+    from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
+
 
 class StreamingQueryHandler:
     """
@@ -130,7 +133,7 @@
     async def execute_streaming_query(
         self,
         user_query: str,
-        agent: Agent,
+        sqlsaber_agent: "SQLSaberAgent",
         cancellation_token: asyncio.Event | None = None,
         message_history: list | None = None,
     ):
@@ -139,21 +142,16 @@
         try:
             # If Anthropic OAuth, inject SQLsaber instructions before the first user prompt
             prepared_prompt: str | list[str] = user_query
-            is_oauth = bool(getattr(agent, "_sqlsaber_is_oauth", False))
             no_history = not message_history
-            if is_oauth and no_history:
-                ib = getattr(agent, "_sqlsaber_instruction_builder", None)
-                mm = getattr(agent, "_sqlsaber_memory_manager", None)
-                db_type = getattr(agent, "_sqlsaber_db_type", "database")
-                db_name = getattr(agent, "_sqlsaber_database_name", None)
-                instructions = (
-                    ib.build_instructions(db_type=db_type) if ib is not None else ""
-                )
-                mem = (
-                    mm.format_memories_for_prompt(db_name)
-                    if (mm is not None and db_name)
-                    else ""
+            if sqlsaber_agent.is_oauth and no_history:
+                instructions = sqlsaber_agent.instruction_builder.build_instructions(
+                    db_type=sqlsaber_agent.db_type
                 )
+                mem = ""
+                if sqlsaber_agent.database_name:
+                    mem = sqlsaber_agent.memory_manager.format_memories_for_prompt(
+                        sqlsaber_agent.database_name
+                    )
                 parts = [p for p in (instructions, mem) if p and str(p).strip()]
                 if parts:
                     injected = "\n\n".join(parts)
@@ -163,7 +161,7 @@
             self.display.live.start_status("Crunching data...")
 
             # Run the agent with our event stream handler
-            run = await agent.run(
+            run = await sqlsaber_agent.agent.run(
                 prepared_prompt,
                 message_history=message_history,
                 event_stream_handler=self._event_stream_handler,
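For reference, a hypothetical call site for the new execute_streaming_query signature, mirroring how InteractiveSession wires it in interactive.py above (names come from this diff; the handler and agent are constructed elsewhere):

import asyncio


async def run_one_turn(handler, sqlsaber_agent, user_query: str):
    # First turn: message_history is None, so the OAuth instruction
    # injection shown above kicks in for Anthropic OAuth agents.
    cancellation_token = asyncio.Event()
    return await handler.execute_streaming_query(
        user_query,
        sqlsaber_agent,
        cancellation_token,
        None,  # message_history
    )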
sqlsaber/cli/threads.py
CHANGED
@@ -148,7 +148,9 @@ def _render_transcript(
             )
         else:
             if is_redirected:
-                console.print(f"**Tool result ({name}):**\n\n{content_str}\n")
+                console.print(
+                    f"**Tool result ({name}):**\n\n{content_str}\n"
+                )
             else:
                 console.print(
                     Panel.fit(
@@ -159,7 +161,9 @@ def _render_transcript(
             )
         except Exception:
             if is_redirected:
-                console.print(f"**Tool result ({name}):**\n\n{content_str}\n")
+                console.print(
+                    f"**Tool result ({name}):**\n\n{content_str}\n"
+                )
             else:
                 console.print(
                     Panel.fit(
@@ -258,10 +262,10 @@ def resume(
 
     async def _run() -> None:
         # Lazy imports to avoid heavy modules at CLI startup
-        from sqlsaber.agents import
+        from sqlsaber.agents import SQLSaberAgent
         from sqlsaber.cli.interactive import InteractiveSession
         from sqlsaber.config.database import DatabaseConfigManager
-        from sqlsaber.database.connection import DatabaseConnection
+        from sqlsaber.database import DatabaseConnection
         from sqlsaber.database.resolver import (
             DatabaseResolutionError,
             resolve_database,
@@ -288,7 +292,7 @@ def resume(
 
         db_conn = DatabaseConnection(connection_string)
         try:
-
+            sqlsaber_agent = SQLSaberAgent(db_conn, db_name)
             history = await store.get_thread_messages(thread_id)
             if console.is_terminal:
                 console.print(Panel.fit(f"Thread: {thread.id}", border_style="blue"))
@@ -297,7 +301,7 @@
         _render_transcript(console, history, None)
         session = InteractiveSession(
             console=console,
-            agent=agent,
+            sqlsaber_agent=sqlsaber_agent,
             db_conn=db_conn,
             database_name=db_name,
             initial_thread_id=thread_id,
sqlsaber/config/settings.py
CHANGED
@@ -46,7 +46,10 @@ class ModelConfigManager:
     def _load_config(self) -> dict[str, Any]:
         """Load configuration from file."""
         if not self.config_file.exists():
-            return {"model": self.DEFAULT_MODEL}
+            return {
+                "model": self.DEFAULT_MODEL,
+                "thinking_enabled": False,
+            }
 
         try:
             with open(self.config_file, "r") as f:
@@ -54,9 +57,15 @@
                 # Ensure we have a model set
                 if "model" not in config:
                     config["model"] = self.DEFAULT_MODEL
+                # Set defaults for thinking if not present
+                if "thinking_enabled" not in config:
+                    config["thinking_enabled"] = False
                 return config
         except (json.JSONDecodeError, IOError):
-            return {"model": self.DEFAULT_MODEL}
+            return {
+                "model": self.DEFAULT_MODEL,
+                "thinking_enabled": False,
+            }
 
     def _save_config(self, config: dict[str, Any]) -> None:
         """Save configuration to file."""
@@ -76,6 +85,17 @@
         config["model"] = model
         self._save_config(config)
 
+    def get_thinking_enabled(self) -> bool:
+        """Get whether thinking is enabled."""
+        config = self._load_config()
+        return config.get("thinking_enabled", False)
+
+    def set_thinking_enabled(self, enabled: bool) -> None:
+        """Set whether thinking is enabled."""
+        config = self._load_config()
+        config["thinking_enabled"] = enabled
+        self._save_config(config)
+
 
 class Config:
     """Configuration class for SQLSaber."""
@@ -86,6 +106,9 @@
         self.api_key_manager = APIKeyManager()
         self.auth_config_manager = AuthConfigManager()
 
+        # Thinking configuration
+        self.thinking_enabled = self.model_config_manager.get_thinking_enabled()
+
         # Authentication method (API key or Anthropic OAuth)
         self.auth_method = self.auth_config_manager.get_auth_method()
 
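The flag persists next to the model choice, so a fresh config file now contains both keys, e.g. {"model": "<default>", "thinking_enabled": false}. A small round-trip sketch using only the accessors added above (the config file location is whatever ModelConfigManager already uses):

from sqlsaber.config.settings import ModelConfigManager

mgr = ModelConfigManager()
mgr.set_thinking_enabled(True)   # writes thinking_enabled into the config file
assert mgr.get_thinking_enabled() is True
# Config reads the flag once at construction time, into self.thinking_enabled.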
sqlsaber/database/__init__.py
CHANGED
@@ -1,9 +1,63 @@
 """Database module for SQLSaber."""
 
-from .connection import DatabaseConnection
+from .base import (
+    DEFAULT_QUERY_TIMEOUT,
+    BaseDatabaseConnection,
+    BaseSchemaIntrospector,
+    ColumnInfo,
+    ForeignKeyInfo,
+    IndexInfo,
+    QueryTimeoutError,
+    SchemaInfo,
+)
+from .csv import CSVConnection, CSVSchemaIntrospector
+from .duckdb import DuckDBConnection, DuckDBSchemaIntrospector
+from .mysql import MySQLConnection, MySQLSchemaIntrospector
+from .postgresql import PostgreSQLConnection, PostgreSQLSchemaIntrospector
 from .schema import SchemaManager
+from .sqlite import SQLiteConnection, SQLiteSchemaIntrospector
+
+
+def DatabaseConnection(connection_string: str) -> BaseDatabaseConnection:
+    """Factory function to create appropriate database connection based on connection string."""
+    if connection_string.startswith("postgresql://"):
+        return PostgreSQLConnection(connection_string)
+    elif connection_string.startswith("mysql://"):
+        return MySQLConnection(connection_string)
+    elif connection_string.startswith("sqlite:///"):
+        return SQLiteConnection(connection_string)
+    elif connection_string.startswith("duckdb://"):
+        return DuckDBConnection(connection_string)
+    elif connection_string.startswith("csv:///"):
+        return CSVConnection(connection_string)
+    else:
+        raise ValueError(
+            f"Unsupported database type in connection string: {connection_string}"
+        )
+
 
 __all__ = [
+    # Base classes and types
+    "BaseDatabaseConnection",
+    "BaseSchemaIntrospector",
+    "ColumnInfo",
+    "DEFAULT_QUERY_TIMEOUT",
+    "ForeignKeyInfo",
+    "IndexInfo",
+    "QueryTimeoutError",
+    "SchemaInfo",
+    # Concrete implementations
+    "PostgreSQLConnection",
+    "MySQLConnection",
+    "SQLiteConnection",
+    "DuckDBConnection",
+    "CSVConnection",
+    "PostgreSQLSchemaIntrospector",
+    "MySQLSchemaIntrospector",
+    "SQLiteSchemaIntrospector",
+    "DuckDBSchemaIntrospector",
+    "CSVSchemaIntrospector",
+    # Factory function and manager
     "DatabaseConnection",
     "SchemaManager",
 ]
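DatabaseConnection is now a factory function rather than a class, so isinstance checks must target the concrete classes or BaseDatabaseConnection, but existing call sites keep the same shape. Illustrative use with invented URLs:

from sqlsaber.database import (
    CSVConnection,
    DatabaseConnection,
    PostgreSQLConnection,
)

conn = DatabaseConnection("postgresql://user:secret@localhost:5432/shop")
assert isinstance(conn, PostgreSQLConnection)

csv_conn = DatabaseConnection("csv:///data/orders.csv?delimiter=;")
assert isinstance(csv_conn, CSVConnection)

# Any other scheme raises:
# DatabaseConnection("oracle://host/db")  -> ValueError: Unsupported database type ...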
sqlsaber/database/base.py
ADDED
@@ -0,0 +1,124 @@
+"""Base classes and type definitions for database connections and schema introspection."""
+
+from abc import ABC, abstractmethod
+from typing import Any, TypedDict
+
+# Default query timeout to prevent runaway queries
+DEFAULT_QUERY_TIMEOUT = 30.0  # seconds
+
+
+class QueryTimeoutError(RuntimeError):
+    """Exception raised when a query exceeds its timeout."""
+
+    def __init__(self, seconds: float):
+        self.timeout = seconds
+        super().__init__(f"Query exceeded timeout of {seconds}s")
+
+
+class ColumnInfo(TypedDict):
+    """Type definition for column information."""
+
+    data_type: str
+    nullable: bool
+    default: str | None
+    max_length: int | None
+    precision: int | None
+    scale: int | None
+
+
+class ForeignKeyInfo(TypedDict):
+    """Type definition for foreign key information."""
+
+    column: str
+    references: dict[str, str]  # {"table": "schema.table", "column": "column_name"}
+
+
+class IndexInfo(TypedDict):
+    """Type definition for index information."""
+
+    name: str
+    columns: list[str]  # ordered
+    unique: bool
+    type: str | None  # btree, gin, FULLTEXT, etc. None if unknown
+
+
+class SchemaInfo(TypedDict):
+    """Type definition for schema information."""
+
+    schema: str
+    name: str
+    type: str
+    columns: dict[str, ColumnInfo]
+    primary_keys: list[str]
+    foreign_keys: list[ForeignKeyInfo]
+    indexes: list[IndexInfo]
+
+
+class BaseDatabaseConnection(ABC):
+    """Abstract base class for database connections."""
+
+    def __init__(self, connection_string: str):
+        self.connection_string = connection_string
+        self._pool = None
+
+    @abstractmethod
+    async def get_pool(self):
+        """Get or create connection pool."""
+        pass
+
+    @abstractmethod
+    async def close(self):
+        """Close the connection pool."""
+        pass
+
+    @abstractmethod
+    async def execute_query(
+        self, query: str, *args, timeout: float | None = None
+    ) -> list[dict[str, Any]]:
+        """Execute a query and return results as list of dicts.
+
+        All queries run in a transaction that is rolled back at the end,
+        ensuring no changes are persisted to the database.
+
+        Args:
+            query: SQL query to execute
+            *args: Query parameters
+            timeout: Query timeout in seconds (overrides default_timeout)
+        """
+        pass
+
+
+class BaseSchemaIntrospector(ABC):
+    """Abstract base class for database-specific schema introspection."""
+
+    @abstractmethod
+    async def get_tables_info(
+        self, connection, table_pattern: str | None = None
+    ) -> dict[str, Any]:
+        """Get tables information for the specific database type."""
+        pass
+
+    @abstractmethod
+    async def get_columns_info(self, connection, tables: list) -> list:
+        """Get columns information for the specific database type."""
+        pass
+
+    @abstractmethod
+    async def get_foreign_keys_info(self, connection, tables: list) -> list:
+        """Get foreign keys information for the specific database type."""
+        pass
+
+    @abstractmethod
+    async def get_primary_keys_info(self, connection, tables: list) -> list:
+        """Get primary keys information for the specific database type."""
+        pass
+
+    @abstractmethod
+    async def get_indexes_info(self, connection, tables: list) -> list:
+        """Get indexes information for the specific database type."""
+        pass
+
+    @abstractmethod
+    async def list_tables_info(self, connection) -> list[dict[str, Any]]:
+        """Get list of tables with basic information."""
+        pass
sqlsaber/database/csv.py
ADDED
@@ -0,0 +1,133 @@
+"""CSV database connection using DuckDB backend."""
+
+import asyncio
+from pathlib import Path
+from typing import Any
+from urllib.parse import parse_qs, urlparse
+
+import duckdb
+
+from .base import DEFAULT_QUERY_TIMEOUT, BaseDatabaseConnection, QueryTimeoutError
+from .duckdb import DuckDBSchemaIntrospector
+
+
+def _execute_duckdb_transaction(
+    conn: duckdb.DuckDBPyConnection, query: str, args: tuple[Any, ...]
+) -> list[dict[str, Any]]:
+    """Run a DuckDB query inside a transaction and return list of dicts."""
+    conn.execute("BEGIN TRANSACTION")
+    try:
+        if args:
+            conn.execute(query, args)
+        else:
+            conn.execute(query)
+
+        if conn.description is None:
+            rows: list[dict[str, Any]] = []
+        else:
+            columns = [col[0] for col in conn.description]
+            data = conn.fetchall()
+            rows = [dict(zip(columns, row)) for row in data]
+
+        conn.execute("ROLLBACK")
+        return rows
+    except Exception:
+        conn.execute("ROLLBACK")
+        raise
+
+
+class CSVConnection(BaseDatabaseConnection):
+    """CSV file connection using DuckDB per query."""
+
+    def __init__(self, connection_string: str):
+        super().__init__(connection_string)
+
+        raw_path = connection_string.replace("csv:///", "", 1)
+        self.csv_path = raw_path.split("?", 1)[0]
+
+        self.delimiter = ","
+        self.encoding = "utf-8"
+        self.has_header = True
+
+        parsed = urlparse(connection_string)
+        if parsed.query:
+            params = parse_qs(parsed.query)
+            self.delimiter = params.get("delimiter", [self.delimiter])[0]
+            self.encoding = params.get("encoding", [self.encoding])[0]
+            self.has_header = params.get("header", ["true"])[0].lower() == "true"
+
+        self.table_name = Path(self.csv_path).stem or "csv_table"
+
+    async def get_pool(self):
+        """CSV connections do not maintain a pool."""
+        return None
+
+    async def close(self):
+        """No persistent resources to close for CSV connections."""
+        pass
+
+    def _quote_identifier(self, identifier: str) -> str:
+        escaped = identifier.replace('"', '""')
+        return f'"{escaped}"'
+
+    def _quote_literal(self, value: str) -> str:
+        escaped = value.replace("'", "''")
+        return f"'{escaped}'"
+
+    def _normalized_encoding(self) -> str | None:
+        encoding = (self.encoding or "").strip()
+        if not encoding or encoding.lower() == "utf-8":
+            return None
+        return encoding.replace("-", "").replace("_", "").upper()
+
+    def _create_view(self, conn: duckdb.DuckDBPyConnection) -> None:
+        header_literal = "TRUE" if self.has_header else "FALSE"
+        option_parts = [f"HEADER={header_literal}"]
+
+        if self.delimiter:
+            option_parts.append(f"DELIM={self._quote_literal(self.delimiter)}")
+
+        encoding = self._normalized_encoding()
+        if encoding:
+            option_parts.append(f"ENCODING={self._quote_literal(encoding)}")
+
+        options_sql = ""
+        if option_parts:
+            options_sql = ", " + ", ".join(option_parts)
+
+        base_relation_sql = (
+            f"read_csv_auto({self._quote_literal(self.csv_path)}{options_sql})"
+        )
+
+        create_view_sql = (
+            f"CREATE VIEW {self._quote_identifier(self.table_name)} AS "
+            f"SELECT * FROM {base_relation_sql}"
+        )
+        conn.execute(create_view_sql)
+
+    async def execute_query(
+        self, query: str, *args, timeout: float | None = None
+    ) -> list[dict[str, Any]]:
+        effective_timeout = timeout or DEFAULT_QUERY_TIMEOUT
+        args_tuple = tuple(args) if args else tuple()
+
+        def _run_query() -> list[dict[str, Any]]:
+            conn = duckdb.connect(":memory:")
+            try:
+                self._create_view(conn)
+                return _execute_duckdb_transaction(conn, query, args_tuple)
+            finally:
+                conn.close()
+
+        try:
+            return await asyncio.wait_for(
+                asyncio.to_thread(_run_query), timeout=effective_timeout
+            )
+        except asyncio.TimeoutError as exc:
+            raise QueryTimeoutError(effective_timeout or 0) from exc
+
+
+class CSVSchemaIntrospector(DuckDBSchemaIntrospector):
+    """CSV-specific schema introspection using DuckDB backend."""
+
+    pass