sqlsaber 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
Potentially problematic release: this version of sqlsaber might be problematic.
- sqlsaber/agents/__init__.py +2 -2
- sqlsaber/agents/base.py +5 -2
- sqlsaber/agents/mcp.py +1 -1
- sqlsaber/agents/pydantic_ai_agent.py +208 -133
- sqlsaber/cli/commands.py +17 -26
- sqlsaber/cli/completers.py +2 -0
- sqlsaber/cli/database.py +18 -7
- sqlsaber/cli/display.py +29 -9
- sqlsaber/cli/interactive.py +28 -16
- sqlsaber/cli/streaming.py +15 -17
- sqlsaber/cli/threads.py +10 -6
- sqlsaber/config/database.py +3 -1
- sqlsaber/config/settings.py +25 -2
- sqlsaber/database/__init__.py +55 -1
- sqlsaber/database/base.py +124 -0
- sqlsaber/database/csv.py +133 -0
- sqlsaber/database/duckdb.py +313 -0
- sqlsaber/database/mysql.py +345 -0
- sqlsaber/database/postgresql.py +328 -0
- sqlsaber/database/resolver.py +7 -3
- sqlsaber/database/schema.py +69 -742
- sqlsaber/database/sqlite.py +258 -0
- sqlsaber/mcp/mcp.py +1 -1
- sqlsaber/tools/sql_tools.py +1 -1
- {sqlsaber-0.24.0.dist-info → sqlsaber-0.26.0.dist-info}/METADATA +45 -10
- sqlsaber-0.26.0.dist-info/RECORD +52 -0
- sqlsaber/database/connection.py +0 -511
- sqlsaber-0.24.0.dist-info/RECORD +0 -47
- {sqlsaber-0.24.0.dist-info → sqlsaber-0.26.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.24.0.dist-info → sqlsaber-0.26.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.24.0.dist-info → sqlsaber-0.26.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/display.py
CHANGED
@@ -8,7 +8,7 @@ rendered with Live.
 import json
 from typing import Sequence, Type
 
-from pydantic_ai.messages import ModelResponsePart, TextPart
+from pydantic_ai.messages import ModelResponsePart, TextPart, ThinkingPart
 from rich.columns import Columns
 from rich.console import Console, ConsoleOptions, RenderResult
 from rich.live import Live
@@ -75,7 +75,7 @@ class LiveMarkdownRenderer:
         self.end()
         self.paragraph_break()
 
-        self._start()
+        self._start(kind)
         self._current_kind = kind
 
     def append(self, text: str | None) -> None:
@@ -87,7 +87,13 @@ class LiveMarkdownRenderer:
         self.ensure_segment(TextPart)
 
         self._buffer += text
-
+
+        # Apply dim styling for thinking segments
+        if self._current_kind == ThinkingPart:
+            content = Markdown(self._buffer, style="dim")
+            self._live.update(content)
+        else:
+            self._live.update(Markdown(self._buffer))
 
     def end(self) -> None:
         """Finalize and stop the current Live segment, if any."""
@@ -95,13 +101,17 @@ class LiveMarkdownRenderer:
             return
         # Persist the *final* render exactly once, then shut Live down.
         buf = self._buffer
+        kind = self._current_kind
        self._live.stop()
        self._live = None
        self._buffer = ""
        self._current_kind = None
        # Print the complete markdown to scroll-back for permanent reference
        if buf:
-
+            if kind == ThinkingPart:
+                self.console.print(Text(buf, style="dim"))
+            else:
+                self.console.print(Markdown(buf))
 
    def end_if_active(self) -> None:
        self.end()
@@ -153,10 +163,20 @@ class LiveMarkdownRenderer:
        text = Text(f" {message}", style="yellow")
        return Columns([spinner, text], expand=False)
 
-    def _start(
+    def _start(
+        self, kind: Type[ModelResponsePart] | None = None, initial_markdown: str = ""
+    ) -> None:
        if self._live is not None:
            self.end()
        self._buffer = initial_markdown or ""
+
+        # Add visual styling for thinking segments
+        if kind == ThinkingPart:
+            if self.console.is_terminal:
+                self.console.print("[dim]💭 Thinking...[/dim]")
+            else:
+                self.console.print("*Thinking...*\n")
+
        # NOTE: Use transient=True so the live widget disappears on exit,
        # giving a clean transition to the final printed result.
        live = Live(
@@ -219,7 +239,9 @@ class DisplayManager:
        if self.console.is_terminal:
            self.console.print("[dim bold]:gear: Executing SQL:[/dim bold]")
            self.show_newline()
-            syntax = Syntax(
+            syntax = Syntax(
+                query, "sql", background_color="default", word_wrap=True
+            )
            self.console.print(syntax)
        else:
            self.console.print("**Executing SQL:**\n")
@@ -271,9 +293,7 @@ class DisplayManager:
                    f"[yellow]... and {len(results) - 20} more rows[/yellow]"
                )
            else:
-                self.console.print(
-                    f"*... and {len(results) - 20} more rows*\n"
-                )
+                self.console.print(f"*... and {len(results) - 20} more rows*\n")
 
    def show_error(self, error_message: str):
        """Display error message."""
sqlsaber/cli/interactive.py
CHANGED
@@ -3,13 +3,13 @@
 import asyncio
 from pathlib import Path
 from textwrap import dedent
+from typing import TYPE_CHECKING
 
 import platformdirs
 from prompt_toolkit import PromptSession
 from prompt_toolkit.history import FileHistory
 from prompt_toolkit.patch_stdout import patch_stdout
 from prompt_toolkit.styles import Style
-from pydantic_ai import Agent
 from rich.console import Console
 from rich.markdown import Markdown
 from rich.panel import Panel
@@ -21,8 +21,9 @@ from sqlsaber.cli.completers import (
 )
 from sqlsaber.cli.display import DisplayManager
 from sqlsaber.cli.streaming import StreamingQueryHandler
-from sqlsaber.database
+from sqlsaber.database import (
     CSVConnection,
+    DuckDBConnection,
     MySQLConnection,
     PostgreSQLConnection,
     SQLiteConnection,
@@ -30,6 +31,9 @@ from sqlsaber.database.connection import (
 from sqlsaber.database.schema import SchemaManager
 from sqlsaber.threads import ThreadStorage
 
+if TYPE_CHECKING:
+    from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
+
 
 def bottom_toolbar():
     return [
@@ -54,7 +58,7 @@ class InteractiveSession:
    def __init__(
        self,
        console: Console,
-
+        sqlsaber_agent: "SQLSaberAgent",
        db_conn,
        database_name: str,
        *,
@@ -62,7 +66,7 @@ class InteractiveSession:
        initial_history: list | None = None,
    ):
        self.console = console
-        self.
+        self.sqlsaber_agent = sqlsaber_agent
        self.db_conn = db_conn
        self.database_name = database_name
        self.display = DisplayManager(console)
@@ -85,8 +89,12 @@ class InteractiveSession:
            if isinstance(self.db_conn, PostgreSQLConnection)
            else "MySQL"
            if isinstance(self.db_conn, MySQLConnection)
+            else "DuckDB"
+            if isinstance(self.db_conn, DuckDBConnection)
+            else "DuckDB"
+            if isinstance(self.db_conn, CSVConnection)
            else "SQLite"
-            if isinstance(self.db_conn,
+            if isinstance(self.db_conn, SQLiteConnection)
            else "database"
        )
 
@@ -171,7 +179,7 @@ class InteractiveSession:
                query_task = asyncio.create_task(
                    self.streaming_handler.execute_streaming_query(
                        user_query,
-                        self.
+                        self.sqlsaber_agent,
                        self.cancellation_token,
                        self.message_history,
                    )
@@ -186,11 +194,6 @@ class InteractiveSession:
                # Use all_messages() so the system prompt and all prior turns are preserved
                self.message_history = run_result.all_messages()
 
-                # Extract title (first user prompt) and model name
-                if not self._thread_id:
-                    title = user_query
-                    model_name = self.agent.model.model_name
-
                # Persist snapshot to thread storage (create or overwrite)
                self._thread_id = await self._threads.save_snapshot(
                    messages_json=run_result.all_messages_json(),
@@ -201,8 +204,8 @@
                if self.first_message:
                    await self._threads.save_metadata(
                        thread_id=self._thread_id,
-                        title=
-                        model_name=model_name,
+                        title=user_query,
+                        model_name=self.sqlsaber_agent.agent.model.model_name,
                    )
            except Exception:
                pass
@@ -264,6 +267,17 @@
                    self._thread_id = None
                    continue
 
+                # Thinking commands
+                if user_query == "/thinking on":
+                    self.sqlsaber_agent.set_thinking(enabled=True)
+                    self.console.print("[green]✓ Thinking enabled[/green]\n")
+                    continue
+
+                if user_query == "/thinking off":
+                    self.sqlsaber_agent.set_thinking(enabled=False)
+                    self.console.print("[green]✓ Thinking disabled[/green]\n")
+                    continue
+
                if memory_text := user_query.strip():
                    # Check if query starts with # for memory addition
                    if memory_text.startswith("#"):
@@ -271,9 +285,7 @@
                        if memory_content:
                            # Add memory via the agent's memory manager
                            try:
-                                mm =
-                                    self.agent, "_sqlsaber_memory_manager", None
-                                )
+                                mm = self.sqlsaber_agent.memory_manager
                                if mm and self.database_name:
                                    memory = mm.add_memory(
                                        self.database_name, memory_content
sqlsaber/cli/streaming.py
CHANGED
@@ -8,9 +8,9 @@ rendered via DisplayManager helpers.
 import asyncio
 import json
 from functools import singledispatchmethod
-from typing import AsyncIterable
+from typing import TYPE_CHECKING, AsyncIterable
 
-from pydantic_ai import
+from pydantic_ai import RunContext
 from pydantic_ai.messages import (
     AgentStreamEvent,
     FunctionToolCallEvent,
@@ -26,6 +26,9 @@ from rich.console import Console
 
 from sqlsaber.cli.display import DisplayManager
 
+if TYPE_CHECKING:
+    from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
+
 
 class StreamingQueryHandler:
     """
@@ -130,7 +133,7 @@
    async def execute_streaming_query(
        self,
        user_query: str,
-
+        sqlsaber_agent: "SQLSaberAgent",
        cancellation_token: asyncio.Event | None = None,
        message_history: list | None = None,
    ):
@@ -139,21 +142,16 @@
        try:
            # If Anthropic OAuth, inject SQLsaber instructions before the first user prompt
            prepared_prompt: str | list[str] = user_query
-            is_oauth = bool(getattr(agent, "_sqlsaber_is_oauth", False))
            no_history = not message_history
-            if is_oauth and no_history:
-
-
-                db_type = getattr(agent, "_sqlsaber_db_type", "database")
-                db_name = getattr(agent, "_sqlsaber_database_name", None)
-                instructions = (
-                    ib.build_instructions(db_type=db_type) if ib is not None else ""
-                )
-                mem = (
-                    mm.format_memories_for_prompt(db_name)
-                    if (mm is not None and db_name)
-                    else ""
+            if sqlsaber_agent.is_oauth and no_history:
+                instructions = sqlsaber_agent.instruction_builder.build_instructions(
+                    db_type=sqlsaber_agent.db_type
                )
+                mem = ""
+                if sqlsaber_agent.database_name:
+                    mem = sqlsaber_agent.memory_manager.format_memories_for_prompt(
+                        sqlsaber_agent.database_name
+                    )
                parts = [p for p in (instructions, mem) if p and str(p).strip()]
                if parts:
                    injected = "\n\n".join(parts)
@@ -163,7 +161,7 @@
            self.display.live.start_status("Crunching data...")
 
            # Run the agent with our event stream handler
-            run = await agent.run(
+            run = await sqlsaber_agent.agent.run(
                prepared_prompt,
                message_history=message_history,
                event_stream_handler=self._event_stream_handler,
sqlsaber/cli/threads.py
CHANGED
@@ -148,7 +148,9 @@ def _render_transcript(
             )
         else:
             if is_redirected:
-                console.print(
+                console.print(
+                    f"**Tool result ({name}):**\n\n{content_str}\n"
+                )
             else:
                 console.print(
                     Panel.fit(
@@ -159,7 +161,9 @@ def _render_transcript(
            )
    except Exception:
        if is_redirected:
-            console.print(
+            console.print(
+                f"**Tool result ({name}):**\n\n{content_str}\n"
+            )
        else:
            console.print(
                Panel.fit(
@@ -258,10 +262,10 @@ def resume(
 
    async def _run() -> None:
        # Lazy imports to avoid heavy modules at CLI startup
-        from sqlsaber.agents import
+        from sqlsaber.agents import SQLSaberAgent
        from sqlsaber.cli.interactive import InteractiveSession
        from sqlsaber.config.database import DatabaseConfigManager
-        from sqlsaber.database
+        from sqlsaber.database import DatabaseConnection
        from sqlsaber.database.resolver import (
            DatabaseResolutionError,
            resolve_database,
@@ -288,7 +292,7 @@ def resume(
 
        db_conn = DatabaseConnection(connection_string)
        try:
-
+            sqlsaber_agent = SQLSaberAgent(db_conn, db_name)
            history = await store.get_thread_messages(thread_id)
            if console.is_terminal:
                console.print(Panel.fit(f"Thread: {thread.id}", border_style="blue"))
@@ -297,7 +301,7 @@ def resume(
            _render_transcript(console, history, None)
            session = InteractiveSession(
                console=console,
-
+                sqlsaber_agent=sqlsaber_agent,
                db_conn=db_conn,
                database_name=db_name,
                initial_thread_id=thread_id,
sqlsaber/config/database.py
CHANGED
@@ -18,7 +18,7 @@ class DatabaseConfig:
     """Database connection configuration."""
 
     name: str
-    type: str  # postgresql, mysql, sqlite, csv
+    type: str  # postgresql, mysql, sqlite, duckdb, csv
     host: str | None
     port: int | None
     database: str
@@ -90,6 +90,8 @@
 
        elif self.type == "sqlite":
            return f"sqlite:///{self.database}"
+        elif self.type == "duckdb":
+            return f"duckdb:///{self.database}"
        elif self.type == "csv":
            # For CSV files, database field contains the file path
            base_url = f"csv:///{self.database}"
sqlsaber/config/settings.py
CHANGED
@@ -46,7 +46,10 @@ class ModelConfigManager:
     def _load_config(self) -> dict[str, Any]:
         """Load configuration from file."""
         if not self.config_file.exists():
-            return {
+            return {
+                "model": self.DEFAULT_MODEL,
+                "thinking_enabled": False,
+            }
 
         try:
             with open(self.config_file, "r") as f:
@@ -54,9 +57,15 @@
                # Ensure we have a model set
                if "model" not in config:
                    config["model"] = self.DEFAULT_MODEL
+                # Set defaults for thinking if not present
+                if "thinking_enabled" not in config:
+                    config["thinking_enabled"] = False
                return config
        except (json.JSONDecodeError, IOError):
-            return {
+            return {
+                "model": self.DEFAULT_MODEL,
+                "thinking_enabled": False,
+            }
 
    def _save_config(self, config: dict[str, Any]) -> None:
        """Save configuration to file."""
@@ -76,6 +85,17 @@
        config["model"] = model
        self._save_config(config)
 
+    def get_thinking_enabled(self) -> bool:
+        """Get whether thinking is enabled."""
+        config = self._load_config()
+        return config.get("thinking_enabled", False)
+
+    def set_thinking_enabled(self, enabled: bool) -> None:
+        """Set whether thinking is enabled."""
+        config = self._load_config()
+        config["thinking_enabled"] = enabled
+        self._save_config(config)
+
 
 class Config:
    """Configuration class for SQLSaber."""
@@ -86,6 +106,9 @@ class Config:
        self.api_key_manager = APIKeyManager()
        self.auth_config_manager = AuthConfigManager()
 
+        # Thinking configuration
+        self.thinking_enabled = self.model_config_manager.get_thinking_enabled()
+
        # Authentication method (API key or Anthropic OAuth)
        self.auth_method = self.auth_config_manager.get_auth_method()
 
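
Usage sketch for the new thinking toggle; it assumes ModelConfigManager (the class shown in this diff) can be constructed without arguments, which is not visible in the hunks above:

from sqlsaber.config.settings import ModelConfigManager

manager = ModelConfigManager()          # assumption: no required constructor arguments
manager.set_thinking_enabled(True)      # persists "thinking_enabled": true to the config file
print(manager.get_thinking_enabled())   # True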
sqlsaber/database/__init__.py
CHANGED
@@ -1,9 +1,63 @@
 """Database module for SQLSaber."""
 
-from .
+from .base import (
+    DEFAULT_QUERY_TIMEOUT,
+    BaseDatabaseConnection,
+    BaseSchemaIntrospector,
+    ColumnInfo,
+    ForeignKeyInfo,
+    IndexInfo,
+    QueryTimeoutError,
+    SchemaInfo,
+)
+from .csv import CSVConnection, CSVSchemaIntrospector
+from .duckdb import DuckDBConnection, DuckDBSchemaIntrospector
+from .mysql import MySQLConnection, MySQLSchemaIntrospector
+from .postgresql import PostgreSQLConnection, PostgreSQLSchemaIntrospector
 from .schema import SchemaManager
+from .sqlite import SQLiteConnection, SQLiteSchemaIntrospector
+
+
+def DatabaseConnection(connection_string: str) -> BaseDatabaseConnection:
+    """Factory function to create appropriate database connection based on connection string."""
+    if connection_string.startswith("postgresql://"):
+        return PostgreSQLConnection(connection_string)
+    elif connection_string.startswith("mysql://"):
+        return MySQLConnection(connection_string)
+    elif connection_string.startswith("sqlite:///"):
+        return SQLiteConnection(connection_string)
+    elif connection_string.startswith("duckdb://"):
+        return DuckDBConnection(connection_string)
+    elif connection_string.startswith("csv:///"):
+        return CSVConnection(connection_string)
+    else:
+        raise ValueError(
+            f"Unsupported database type in connection string: {connection_string}"
+        )
+
 
 __all__ = [
+    # Base classes and types
+    "BaseDatabaseConnection",
+    "BaseSchemaIntrospector",
+    "ColumnInfo",
+    "DEFAULT_QUERY_TIMEOUT",
+    "ForeignKeyInfo",
+    "IndexInfo",
+    "QueryTimeoutError",
+    "SchemaInfo",
+    # Concrete implementations
+    "PostgreSQLConnection",
+    "MySQLConnection",
+    "SQLiteConnection",
+    "DuckDBConnection",
+    "CSVConnection",
+    "PostgreSQLSchemaIntrospector",
+    "MySQLSchemaIntrospector",
+    "SQLiteSchemaIntrospector",
+    "DuckDBSchemaIntrospector",
+    "CSVSchemaIntrospector",
+    # Factory function and manager
     "DatabaseConnection",
     "SchemaManager",
 ]
sqlsaber/database/base.py
ADDED
@@ -0,0 +1,124 @@
+"""Base classes and type definitions for database connections and schema introspection."""
+
+from abc import ABC, abstractmethod
+from typing import Any, TypedDict
+
+# Default query timeout to prevent runaway queries
+DEFAULT_QUERY_TIMEOUT = 30.0  # seconds
+
+
+class QueryTimeoutError(RuntimeError):
+    """Exception raised when a query exceeds its timeout."""
+
+    def __init__(self, seconds: float):
+        self.timeout = seconds
+        super().__init__(f"Query exceeded timeout of {seconds}s")
+
+
+class ColumnInfo(TypedDict):
+    """Type definition for column information."""
+
+    data_type: str
+    nullable: bool
+    default: str | None
+    max_length: int | None
+    precision: int | None
+    scale: int | None
+
+
+class ForeignKeyInfo(TypedDict):
+    """Type definition for foreign key information."""
+
+    column: str
+    references: dict[str, str]  # {"table": "schema.table", "column": "column_name"}
+
+
+class IndexInfo(TypedDict):
+    """Type definition for index information."""
+
+    name: str
+    columns: list[str]  # ordered
+    unique: bool
+    type: str | None  # btree, gin, FULLTEXT, etc. None if unknown
+
+
+class SchemaInfo(TypedDict):
+    """Type definition for schema information."""
+
+    schema: str
+    name: str
+    type: str
+    columns: dict[str, ColumnInfo]
+    primary_keys: list[str]
+    foreign_keys: list[ForeignKeyInfo]
+    indexes: list[IndexInfo]
+
+
+class BaseDatabaseConnection(ABC):
+    """Abstract base class for database connections."""
+
+    def __init__(self, connection_string: str):
+        self.connection_string = connection_string
+        self._pool = None
+
+    @abstractmethod
+    async def get_pool(self):
+        """Get or create connection pool."""
+        pass
+
+    @abstractmethod
+    async def close(self):
+        """Close the connection pool."""
+        pass
+
+    @abstractmethod
+    async def execute_query(
+        self, query: str, *args, timeout: float | None = None
+    ) -> list[dict[str, Any]]:
+        """Execute a query and return results as list of dicts.
+
+        All queries run in a transaction that is rolled back at the end,
+        ensuring no changes are persisted to the database.
+
+        Args:
+            query: SQL query to execute
+            *args: Query parameters
+            timeout: Query timeout in seconds (overrides default_timeout)
+        """
+        pass
+
+
+class BaseSchemaIntrospector(ABC):
+    """Abstract base class for database-specific schema introspection."""
+
+    @abstractmethod
+    async def get_tables_info(
+        self, connection, table_pattern: str | None = None
+    ) -> dict[str, Any]:
+        """Get tables information for the specific database type."""
+        pass
+
+    @abstractmethod
+    async def get_columns_info(self, connection, tables: list) -> list:
+        """Get columns information for the specific database type."""
+        pass
+
+    @abstractmethod
+    async def get_foreign_keys_info(self, connection, tables: list) -> list:
+        """Get foreign keys information for the specific database type."""
+        pass
+
+    @abstractmethod
+    async def get_primary_keys_info(self, connection, tables: list) -> list:
+        """Get primary keys information for the specific database type."""
+        pass
+
+    @abstractmethod
+    async def get_indexes_info(self, connection, tables: list) -> list:
+        """Get indexes information for the specific database type."""
+        pass
+
+    @abstractmethod
+    async def list_tables_info(self, connection) -> list[dict[str, Any]]:
+        """Get list of tables with basic information."""
+        pass