sqlsaber 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/base.py +4 -1
- sqlsaber/agents/pydantic_ai_agent.py +4 -1
- sqlsaber/cli/commands.py +19 -11
- sqlsaber/cli/database.py +17 -6
- sqlsaber/cli/display.py +49 -19
- sqlsaber/cli/interactive.py +6 -1
- sqlsaber/cli/threads.py +41 -18
- sqlsaber/config/database.py +3 -1
- sqlsaber/database/connection.py +123 -99
- sqlsaber/database/resolver.py +7 -3
- sqlsaber/database/schema.py +377 -1
- sqlsaber/tools/sql_tools.py +6 -0
- {sqlsaber-0.23.0.dist-info → sqlsaber-0.25.0.dist-info}/METADATA +4 -3
- {sqlsaber-0.23.0.dist-info → sqlsaber-0.25.0.dist-info}/RECORD +17 -17
- {sqlsaber-0.23.0.dist-info → sqlsaber-0.25.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.23.0.dist-info → sqlsaber-0.25.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.23.0.dist-info → sqlsaber-0.25.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/base.py
CHANGED
|
@@ -8,6 +8,7 @@ from typing import Any, AsyncIterator
|
|
|
8
8
|
from sqlsaber.database.connection import (
|
|
9
9
|
BaseDatabaseConnection,
|
|
10
10
|
CSVConnection,
|
|
11
|
+
DuckDBConnection,
|
|
11
12
|
MySQLConnection,
|
|
12
13
|
PostgreSQLConnection,
|
|
13
14
|
SQLiteConnection,
|
|
@@ -51,7 +52,9 @@ class BaseSQLAgent(ABC):
|
|
|
51
52
|
elif isinstance(self.db, SQLiteConnection):
|
|
52
53
|
return "SQLite"
|
|
53
54
|
elif isinstance(self.db, CSVConnection):
|
|
54
|
-
return "
|
|
55
|
+
return "DuckDB"
|
|
56
|
+
elif isinstance(self.db, DuckDBConnection):
|
|
57
|
+
return "DuckDB"
|
|
55
58
|
else:
|
|
56
59
|
return "database" # Fallback
|
|
57
60
|
|
|
@@ -17,6 +17,7 @@ from sqlsaber.config.settings import Config
|
|
|
17
17
|
from sqlsaber.database.connection import (
|
|
18
18
|
BaseDatabaseConnection,
|
|
19
19
|
CSVConnection,
|
|
20
|
+
DuckDBConnection,
|
|
20
21
|
MySQLConnection,
|
|
21
22
|
PostgreSQLConnection,
|
|
22
23
|
SQLiteConnection,
|
|
@@ -169,7 +170,9 @@ def _get_database_type_name(db: BaseDatabaseConnection) -> str:
|
|
|
169
170
|
return "MySQL"
|
|
170
171
|
elif isinstance(db, SQLiteConnection):
|
|
171
172
|
return "SQLite"
|
|
173
|
+
elif isinstance(db, DuckDBConnection):
|
|
174
|
+
return "DuckDB"
|
|
172
175
|
elif isinstance(db, CSVConnection):
|
|
173
|
-
return "
|
|
176
|
+
return "DuckDB"
|
|
174
177
|
else:
|
|
175
178
|
return "database"
|
sqlsaber/cli/commands.py
CHANGED
|
@@ -46,7 +46,7 @@ def meta_handler(
|
|
|
46
46
|
str | None,
|
|
47
47
|
cyclopts.Parameter(
|
|
48
48
|
["--database", "-d"],
|
|
49
|
-
help="Database connection name, file path (CSV/SQLite), or connection string (postgresql://, mysql://) (uses default if not specified)",
|
|
49
|
+
help="Database connection name, file path (CSV/SQLite/DuckDB), or connection string (postgresql://, mysql://, duckdb://) (uses default if not specified)",
|
|
50
50
|
),
|
|
51
51
|
] = None,
|
|
52
52
|
):
|
|
@@ -59,8 +59,10 @@ def meta_handler(
|
|
|
59
59
|
saber -d mydb "show me users" # Run a query with specific database
|
|
60
60
|
saber -d data.csv "show me users" # Run a query with ad-hoc CSV file
|
|
61
61
|
saber -d data.db "show me users" # Run a query with ad-hoc SQLite file
|
|
62
|
+
saber -d data.duckdb "show me users" # Run a query with ad-hoc DuckDB file
|
|
62
63
|
saber -d "postgresql://user:pass@host:5432/db" "show users" # PostgreSQL connection string
|
|
63
64
|
saber -d "mysql://user:pass@host:3306/db" "show users" # MySQL connection string
|
|
65
|
+
saber -d "duckdb:///data.duckdb" "show users" # DuckDB connection string
|
|
64
66
|
echo "show me all users" | saber # Read query from stdin
|
|
65
67
|
cat query.txt | saber # Read query from file via stdin
|
|
66
68
|
"""
|
|
@@ -80,7 +82,7 @@ def query(
|
|
|
80
82
|
str | None,
|
|
81
83
|
cyclopts.Parameter(
|
|
82
84
|
["--database", "-d"],
|
|
83
|
-
help="Database connection name, file path (CSV/SQLite), or connection string (postgresql://, mysql://) (uses default if not specified)",
|
|
85
|
+
help="Database connection name, file path (CSV/SQLite/DuckDB), or connection string (postgresql://, mysql://, duckdb://) (uses default if not specified)",
|
|
84
86
|
),
|
|
85
87
|
] = None,
|
|
86
88
|
):
|
|
@@ -97,8 +99,10 @@ def query(
|
|
|
97
99
|
saber "show me all users" # Run a single query
|
|
98
100
|
saber -d data.csv "show users" # Run a query with ad-hoc CSV file
|
|
99
101
|
saber -d data.db "show users" # Run a query with ad-hoc SQLite file
|
|
102
|
+
saber -d data.duckdb "show users" # Run a query with ad-hoc DuckDB file
|
|
100
103
|
saber -d "postgresql://user:pass@host:5432/db" "show users" # PostgreSQL connection string
|
|
101
104
|
saber -d "mysql://user:pass@host:3306/db" "show users" # MySQL connection string
|
|
105
|
+
saber -d "duckdb:///data.duckdb" "show users" # DuckDB connection string
|
|
102
106
|
echo "show me all users" | saber # Read query from stdin
|
|
103
107
|
"""
|
|
104
108
|
|
|
@@ -111,6 +115,7 @@ def query(
|
|
|
111
115
|
from sqlsaber.database.connection import (
|
|
112
116
|
CSVConnection,
|
|
113
117
|
DatabaseConnection,
|
|
118
|
+
DuckDBConnection,
|
|
114
119
|
MySQLConnection,
|
|
115
120
|
PostgreSQLConnection,
|
|
116
121
|
SQLiteConnection,
|
|
@@ -149,15 +154,18 @@ def query(
|
|
|
149
154
|
# Single query mode with streaming
|
|
150
155
|
streaming_handler = StreamingQueryHandler(console)
|
|
151
156
|
# Compute DB type for the greeting line
|
|
152
|
-
|
|
153
|
-
"PostgreSQL"
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
)
|
|
157
|
+
if isinstance(db_conn, PostgreSQLConnection):
|
|
158
|
+
db_type = "PostgreSQL"
|
|
159
|
+
elif isinstance(db_conn, MySQLConnection):
|
|
160
|
+
db_type = "MySQL"
|
|
161
|
+
elif isinstance(db_conn, DuckDBConnection):
|
|
162
|
+
db_type = "DuckDB"
|
|
163
|
+
elif isinstance(db_conn, SQLiteConnection):
|
|
164
|
+
db_type = "SQLite"
|
|
165
|
+
elif isinstance(db_conn, CSVConnection):
|
|
166
|
+
db_type = "DuckDB"
|
|
167
|
+
else:
|
|
168
|
+
db_type = "database"
|
|
161
169
|
console.print(
|
|
162
170
|
f"[bold blue]Connected to:[/bold blue] {db_name} ({db_type})\n"
|
|
163
171
|
)
|
sqlsaber/cli/database.py
CHANGED
|
@@ -31,7 +31,7 @@ def add(
|
|
|
31
31
|
str,
|
|
32
32
|
cyclopts.Parameter(
|
|
33
33
|
["--type", "-t"],
|
|
34
|
-
help="Database type (postgresql, mysql, sqlite)",
|
|
34
|
+
help="Database type (postgresql, mysql, sqlite, duckdb)",
|
|
35
35
|
),
|
|
36
36
|
] = "postgresql",
|
|
37
37
|
host: Annotated[
|
|
@@ -87,17 +87,17 @@ def add(
|
|
|
87
87
|
if not type or type == "postgresql":
|
|
88
88
|
type = questionary.select(
|
|
89
89
|
"Database type:",
|
|
90
|
-
choices=["postgresql", "mysql", "sqlite"],
|
|
90
|
+
choices=["postgresql", "mysql", "sqlite", "duckdb"],
|
|
91
91
|
default="postgresql",
|
|
92
92
|
).ask()
|
|
93
93
|
|
|
94
|
-
if type
|
|
95
|
-
# SQLite only
|
|
94
|
+
if type in {"sqlite", "duckdb"}:
|
|
95
|
+
# SQLite/DuckDB only need database file path
|
|
96
96
|
database = database or questionary.path("Database file path:").ask()
|
|
97
97
|
database = str(Path(database).expanduser().resolve())
|
|
98
98
|
host = "localhost"
|
|
99
99
|
port = 0
|
|
100
|
-
username =
|
|
100
|
+
username = type
|
|
101
101
|
password = ""
|
|
102
102
|
else:
|
|
103
103
|
# PostgreSQL/MySQL need connection details
|
|
@@ -182,6 +182,17 @@ def add(
|
|
|
182
182
|
port = 0
|
|
183
183
|
username = "sqlite"
|
|
184
184
|
password = ""
|
|
185
|
+
elif type == "duckdb":
|
|
186
|
+
if not database:
|
|
187
|
+
console.print(
|
|
188
|
+
"[bold red]Error:[/bold red] Database file path is required for DuckDB"
|
|
189
|
+
)
|
|
190
|
+
sys.exit(1)
|
|
191
|
+
database = str(Path(database).expanduser().resolve())
|
|
192
|
+
host = "localhost"
|
|
193
|
+
port = 0
|
|
194
|
+
username = "duckdb"
|
|
195
|
+
password = ""
|
|
185
196
|
else:
|
|
186
197
|
if not all([host, database, username]):
|
|
187
198
|
console.print(
|
|
@@ -264,7 +275,7 @@ def list():
|
|
|
264
275
|
if db.ssl_ca or db.ssl_cert:
|
|
265
276
|
ssl_status += " (certs)"
|
|
266
277
|
else:
|
|
267
|
-
ssl_status = "disabled" if db.type
|
|
278
|
+
ssl_status = "disabled" if db.type not in {"sqlite", "duckdb"} else "N/A"
|
|
268
279
|
|
|
269
280
|
table.add_row(
|
|
270
281
|
db.name,
|
sqlsaber/cli/display.py
CHANGED
|
@@ -198,22 +198,32 @@ class DisplayManager:
|
|
|
198
198
|
# Normalized leading blank line before tool headers
|
|
199
199
|
self.show_newline()
|
|
200
200
|
if tool_name == "list_tables":
|
|
201
|
-
self.console.
|
|
202
|
-
|
|
203
|
-
|
|
201
|
+
if self.console.is_terminal:
|
|
202
|
+
self.console.print(
|
|
203
|
+
"[dim bold]:gear: Discovering available tables[/dim bold]"
|
|
204
|
+
)
|
|
205
|
+
else:
|
|
206
|
+
self.console.print("**Discovering available tables**\n")
|
|
204
207
|
elif tool_name == "introspect_schema":
|
|
205
208
|
pattern = tool_input.get("table_pattern", "all tables")
|
|
206
|
-
self.console.
|
|
207
|
-
|
|
208
|
-
|
|
209
|
+
if self.console.is_terminal:
|
|
210
|
+
self.console.print(
|
|
211
|
+
f"[dim bold]:gear: Examining schema for: {pattern}[/dim bold]"
|
|
212
|
+
)
|
|
213
|
+
else:
|
|
214
|
+
self.console.print(f"**Examining schema for:** {pattern}\n")
|
|
209
215
|
elif tool_name == "execute_sql":
|
|
210
216
|
# For streaming, we render SQL via LiveMarkdownRenderer; keep Syntax
|
|
211
217
|
# rendering for threads show/resume. Controlled by include_sql flag.
|
|
212
218
|
query = tool_input.get("query", "")
|
|
213
|
-
self.console.
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
219
|
+
if self.console.is_terminal:
|
|
220
|
+
self.console.print("[dim bold]:gear: Executing SQL:[/dim bold]")
|
|
221
|
+
self.show_newline()
|
|
222
|
+
syntax = Syntax(query, "sql", background_color="default", word_wrap=True)
|
|
223
|
+
self.console.print(syntax)
|
|
224
|
+
else:
|
|
225
|
+
self.console.print("**Executing SQL:**\n")
|
|
226
|
+
self.console.print(f"```sql\n{query}\n```\n")
|
|
217
227
|
|
|
218
228
|
def show_text_stream(self, text: str):
|
|
219
229
|
"""Display streaming text."""
|
|
@@ -225,9 +235,12 @@ class DisplayManager:
|
|
|
225
235
|
if not results:
|
|
226
236
|
return
|
|
227
237
|
|
|
228
|
-
self.console.
|
|
229
|
-
|
|
230
|
-
|
|
238
|
+
if self.console.is_terminal:
|
|
239
|
+
self.console.print(
|
|
240
|
+
f"\n[bold magenta]Results ({len(results)} rows):[/bold magenta]"
|
|
241
|
+
)
|
|
242
|
+
else:
|
|
243
|
+
self.console.print(f"\n**Results ({len(results)} rows):**\n")
|
|
231
244
|
|
|
232
245
|
# Create table with columns from first result
|
|
233
246
|
all_columns = list(results[0].keys())
|
|
@@ -235,9 +248,14 @@ class DisplayManager:
|
|
|
235
248
|
|
|
236
249
|
# Show warning if columns were truncated
|
|
237
250
|
if len(all_columns) > 15:
|
|
238
|
-
self.console.
|
|
239
|
-
|
|
240
|
-
|
|
251
|
+
if self.console.is_terminal:
|
|
252
|
+
self.console.print(
|
|
253
|
+
f"[yellow]Note: Showing first 15 of {len(all_columns)} columns[/yellow]"
|
|
254
|
+
)
|
|
255
|
+
else:
|
|
256
|
+
self.console.print(
|
|
257
|
+
f"*Note: Showing first 15 of {len(all_columns)} columns*\n"
|
|
258
|
+
)
|
|
241
259
|
|
|
242
260
|
table = self._create_table(display_columns)
|
|
243
261
|
|
|
@@ -248,9 +266,14 @@ class DisplayManager:
|
|
|
248
266
|
self.console.print(table)
|
|
249
267
|
|
|
250
268
|
if len(results) > 20:
|
|
251
|
-
self.console.
|
|
252
|
-
|
|
253
|
-
|
|
269
|
+
if self.console.is_terminal:
|
|
270
|
+
self.console.print(
|
|
271
|
+
f"[yellow]... and {len(results) - 20} more rows[/yellow]"
|
|
272
|
+
)
|
|
273
|
+
else:
|
|
274
|
+
self.console.print(
|
|
275
|
+
f"*... and {len(results) - 20} more rows*\n"
|
|
276
|
+
)
|
|
254
277
|
|
|
255
278
|
def show_error(self, error_message: str):
|
|
256
279
|
"""Display error message."""
|
|
@@ -385,6 +408,13 @@ class DisplayManager:
|
|
|
385
408
|
for fk in foreign_keys:
|
|
386
409
|
self.console.print(f" • {fk}")
|
|
387
410
|
|
|
411
|
+
# Show indexes
|
|
412
|
+
indexes = table_info.get("indexes", [])
|
|
413
|
+
if indexes:
|
|
414
|
+
self.console.print("[bold blue]Indexes:[/bold blue]")
|
|
415
|
+
for idx in indexes:
|
|
416
|
+
self.console.print(f" • {idx}")
|
|
417
|
+
|
|
388
418
|
except json.JSONDecodeError:
|
|
389
419
|
self.show_error("Failed to parse schema data")
|
|
390
420
|
except Exception as e:
|
sqlsaber/cli/interactive.py
CHANGED
|
@@ -23,6 +23,7 @@ from sqlsaber.cli.display import DisplayManager
|
|
|
23
23
|
from sqlsaber.cli.streaming import StreamingQueryHandler
|
|
24
24
|
from sqlsaber.database.connection import (
|
|
25
25
|
CSVConnection,
|
|
26
|
+
DuckDBConnection,
|
|
26
27
|
MySQLConnection,
|
|
27
28
|
PostgreSQLConnection,
|
|
28
29
|
SQLiteConnection,
|
|
@@ -85,8 +86,12 @@ class InteractiveSession:
|
|
|
85
86
|
if isinstance(self.db_conn, PostgreSQLConnection)
|
|
86
87
|
else "MySQL"
|
|
87
88
|
if isinstance(self.db_conn, MySQLConnection)
|
|
89
|
+
else "DuckDB"
|
|
90
|
+
if isinstance(self.db_conn, DuckDBConnection)
|
|
91
|
+
else "DuckDB"
|
|
92
|
+
if isinstance(self.db_conn, CSVConnection)
|
|
88
93
|
else "SQLite"
|
|
89
|
-
if isinstance(self.db_conn,
|
|
94
|
+
if isinstance(self.db_conn, SQLiteConnection)
|
|
90
95
|
else "database"
|
|
91
96
|
)
|
|
92
97
|
|
sqlsaber/cli/threads.py
CHANGED
|
@@ -38,6 +38,8 @@ def _render_transcript(
|
|
|
38
38
|
from sqlsaber.cli.display import DisplayManager
|
|
39
39
|
|
|
40
40
|
dm = DisplayManager(console)
|
|
41
|
+
# Check if output is being redirected (for clean markdown export)
|
|
42
|
+
is_redirected = not console.is_terminal
|
|
41
43
|
|
|
42
44
|
# Locate indices of user prompts
|
|
43
45
|
user_indices: list[int] = []
|
|
@@ -78,11 +80,17 @@ def _render_transcript(
|
|
|
78
80
|
parts.append(str(seg))
|
|
79
81
|
text = "\n".join([s for s in parts if s]) or None
|
|
80
82
|
if text:
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
83
|
+
if is_redirected:
|
|
84
|
+
console.print(f"**User:**\n\n{text}\n")
|
|
85
|
+
else:
|
|
86
|
+
console.print(
|
|
87
|
+
Panel.fit(Markdown(text), title="User", border_style="cyan")
|
|
88
|
+
)
|
|
84
89
|
return
|
|
85
|
-
|
|
90
|
+
if is_redirected:
|
|
91
|
+
console.print("**User:** (no content)\n")
|
|
92
|
+
else:
|
|
93
|
+
console.print(Panel.fit("(no content)", title="User", border_style="cyan"))
|
|
86
94
|
|
|
87
95
|
def _render_response(message: ModelMessage) -> None:
|
|
88
96
|
for part in getattr(message, "parts", []):
|
|
@@ -90,11 +98,14 @@ def _render_transcript(
|
|
|
90
98
|
if kind == "text":
|
|
91
99
|
text = getattr(part, "content", "")
|
|
92
100
|
if isinstance(text, str) and text.strip():
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
101
|
+
if is_redirected:
|
|
102
|
+
console.print(f"**Assistant:**\n\n{text}\n")
|
|
103
|
+
else:
|
|
104
|
+
console.print(
|
|
105
|
+
Panel.fit(
|
|
106
|
+
Markdown(text), title="Assistant", border_style="green"
|
|
107
|
+
)
|
|
96
108
|
)
|
|
97
|
-
)
|
|
98
109
|
elif kind in ("tool-call", "builtin-tool-call"):
|
|
99
110
|
name = getattr(part, "tool_name", "tool")
|
|
100
111
|
args = getattr(part, "args", None)
|
|
@@ -135,6 +146,20 @@ def _render_transcript(
|
|
|
135
146
|
dm.show_sql_error(
|
|
136
147
|
data.get("error"), data.get("suggestions")
|
|
137
148
|
)
|
|
149
|
+
else:
|
|
150
|
+
if is_redirected:
|
|
151
|
+
console.print(f"**Tool result ({name}):**\n\n{content_str}\n")
|
|
152
|
+
else:
|
|
153
|
+
console.print(
|
|
154
|
+
Panel.fit(
|
|
155
|
+
content_str,
|
|
156
|
+
title=f"Tool result: {name}",
|
|
157
|
+
border_style="yellow",
|
|
158
|
+
)
|
|
159
|
+
)
|
|
160
|
+
except Exception:
|
|
161
|
+
if is_redirected:
|
|
162
|
+
console.print(f"**Tool result ({name}):**\n\n{content_str}\n")
|
|
138
163
|
else:
|
|
139
164
|
console.print(
|
|
140
165
|
Panel.fit(
|
|
@@ -143,7 +168,10 @@ def _render_transcript(
|
|
|
143
168
|
border_style="yellow",
|
|
144
169
|
)
|
|
145
170
|
)
|
|
146
|
-
|
|
171
|
+
else:
|
|
172
|
+
if is_redirected:
|
|
173
|
+
console.print(f"**Tool result ({name}):**\n\n{content_str}\n")
|
|
174
|
+
else:
|
|
147
175
|
console.print(
|
|
148
176
|
Panel.fit(
|
|
149
177
|
content_str,
|
|
@@ -151,14 +179,6 @@ def _render_transcript(
|
|
|
151
179
|
border_style="yellow",
|
|
152
180
|
)
|
|
153
181
|
)
|
|
154
|
-
else:
|
|
155
|
-
console.print(
|
|
156
|
-
Panel.fit(
|
|
157
|
-
content_str,
|
|
158
|
-
title=f"Tool result: {name}",
|
|
159
|
-
border_style="yellow",
|
|
160
|
-
)
|
|
161
|
-
)
|
|
162
182
|
# Thinking parts omitted
|
|
163
183
|
|
|
164
184
|
for start_idx, end_idx in slices or [(0, len(all_msgs))]:
|
|
@@ -270,7 +290,10 @@ def resume(
|
|
|
270
290
|
try:
|
|
271
291
|
agent = build_sqlsaber_agent(db_conn, db_name)
|
|
272
292
|
history = await store.get_thread_messages(thread_id)
|
|
273
|
-
console.
|
|
293
|
+
if console.is_terminal:
|
|
294
|
+
console.print(Panel.fit(f"Thread: {thread.id}", border_style="blue"))
|
|
295
|
+
else:
|
|
296
|
+
console.print(f"# Thread: {thread.id}\n")
|
|
274
297
|
_render_transcript(console, history, None)
|
|
275
298
|
session = InteractiveSession(
|
|
276
299
|
console=console,
|
sqlsaber/config/database.py
CHANGED
|
@@ -18,7 +18,7 @@ class DatabaseConfig:
|
|
|
18
18
|
"""Database connection configuration."""
|
|
19
19
|
|
|
20
20
|
name: str
|
|
21
|
-
type: str # postgresql, mysql, sqlite, csv
|
|
21
|
+
type: str # postgresql, mysql, sqlite, duckdb, csv
|
|
22
22
|
host: str | None
|
|
23
23
|
port: int | None
|
|
24
24
|
database: str
|
|
@@ -90,6 +90,8 @@ class DatabaseConfig:
|
|
|
90
90
|
|
|
91
91
|
elif self.type == "sqlite":
|
|
92
92
|
return f"sqlite:///{self.database}"
|
|
93
|
+
elif self.type == "duckdb":
|
|
94
|
+
return f"duckdb:///{self.database}"
|
|
93
95
|
elif self.type == "csv":
|
|
94
96
|
# For CSV files, database field contains the file path
|
|
95
97
|
base_url = f"csv:///{self.database}"
|