sqlsaber 0.34.0__py3-none-any.whl → 0.36.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/application/db_setup.py +38 -2
- sqlsaber/cli/commands.py +5 -1
- sqlsaber/cli/database.py +160 -12
- sqlsaber/cli/display.py +19 -3
- sqlsaber/cli/interactive.py +6 -2
- sqlsaber/cli/threads.py +4 -2
- sqlsaber/config/database.py +14 -1
- sqlsaber/database/__init__.py +13 -6
- sqlsaber/database/base.py +59 -0
- sqlsaber/database/duckdb.py +68 -30
- sqlsaber/database/mysql.py +37 -10
- sqlsaber/database/postgresql.py +14 -18
- sqlsaber/database/resolver.py +17 -7
- sqlsaber/database/schema.py +3 -0
- sqlsaber/database/sqlite.py +18 -5
- sqlsaber/tools/sql_tools.py +32 -21
- {sqlsaber-0.34.0.dist-info → sqlsaber-0.36.0.dist-info}/METADATA +1 -1
- {sqlsaber-0.34.0.dist-info → sqlsaber-0.36.0.dist-info}/RECORD +21 -21
- {sqlsaber-0.34.0.dist-info → sqlsaber-0.36.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.34.0.dist-info → sqlsaber-0.36.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.34.0.dist-info → sqlsaber-0.36.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/application/db_setup.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Shared database setup logic for onboarding and CLI."""
|
|
2
2
|
|
|
3
3
|
import getpass
|
|
4
|
-
from dataclasses import dataclass
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
|
|
7
7
|
from sqlsaber.application.prompts import Prompter
|
|
@@ -11,6 +11,21 @@ from sqlsaber.theme.manager import create_console
|
|
|
11
11
|
console = create_console()
|
|
12
12
|
|
|
13
13
|
|
|
14
|
+
def _normalize_schemas(schemas: list[str]) -> list[str]:
|
|
15
|
+
"""Deduplicate schema list while preserving order and case."""
|
|
16
|
+
normalized: list[str] = []
|
|
17
|
+
seen: set[str] = set()
|
|
18
|
+
for schema in schemas:
|
|
19
|
+
name = schema.strip()
|
|
20
|
+
if not name:
|
|
21
|
+
continue
|
|
22
|
+
if name in seen:
|
|
23
|
+
continue
|
|
24
|
+
seen.add(name)
|
|
25
|
+
normalized.append(name)
|
|
26
|
+
return normalized
|
|
27
|
+
|
|
28
|
+
|
|
14
29
|
@dataclass
|
|
15
30
|
class DatabaseInput:
|
|
16
31
|
"""Input data for database configuration."""
|
|
@@ -26,6 +41,7 @@ class DatabaseInput:
|
|
|
26
41
|
ssl_ca: str | None = None
|
|
27
42
|
ssl_cert: str | None = None
|
|
28
43
|
ssl_key: str | None = None
|
|
44
|
+
exclude_schemas: list[str] = field(default_factory=list)
|
|
29
45
|
|
|
30
46
|
|
|
31
47
|
async def collect_db_input(
|
|
@@ -69,11 +85,20 @@ async def collect_db_input(
|
|
|
69
85
|
port = 0
|
|
70
86
|
username = db_type
|
|
71
87
|
password = ""
|
|
88
|
+
exclude_schemas: list[str] = []
|
|
72
89
|
ssl_mode = None
|
|
73
90
|
ssl_ca = None
|
|
74
91
|
ssl_cert = None
|
|
75
92
|
ssl_key = None
|
|
76
93
|
|
|
94
|
+
if db_type == "duckdb":
|
|
95
|
+
exclude_prompt = await prompter.text(
|
|
96
|
+
"Schemas to exclude (comma separated, optional):", default=""
|
|
97
|
+
)
|
|
98
|
+
if exclude_prompt is None:
|
|
99
|
+
return None
|
|
100
|
+
exclude_schemas = _normalize_schemas(exclude_prompt.split(","))
|
|
101
|
+
|
|
77
102
|
else:
|
|
78
103
|
# PostgreSQL/MySQL need connection details
|
|
79
104
|
host = await prompter.text("Host:", default="localhost")
|
|
@@ -155,6 +180,13 @@ async def collect_db_input(
|
|
|
155
180
|
"SSL client private key file:"
|
|
156
181
|
)
|
|
157
182
|
|
|
183
|
+
exclude_prompt = await prompter.text(
|
|
184
|
+
"Schemas to exclude (comma separated, optional):", default=""
|
|
185
|
+
)
|
|
186
|
+
if exclude_prompt is None:
|
|
187
|
+
return None
|
|
188
|
+
exclude_schemas = _normalize_schemas(exclude_prompt.split(","))
|
|
189
|
+
|
|
158
190
|
return DatabaseInput(
|
|
159
191
|
name=name,
|
|
160
192
|
type=db_type,
|
|
@@ -167,6 +199,7 @@ async def collect_db_input(
|
|
|
167
199
|
ssl_ca=ssl_ca,
|
|
168
200
|
ssl_cert=ssl_cert,
|
|
169
201
|
ssl_key=ssl_key,
|
|
202
|
+
exclude_schemas=exclude_schemas,
|
|
170
203
|
)
|
|
171
204
|
|
|
172
205
|
|
|
@@ -183,6 +216,7 @@ def build_config(db_input: DatabaseInput) -> DatabaseConfig:
|
|
|
183
216
|
ssl_ca=db_input.ssl_ca,
|
|
184
217
|
ssl_cert=db_input.ssl_cert,
|
|
185
218
|
ssl_key=db_input.ssl_key,
|
|
219
|
+
exclude_schemas=_normalize_schemas(db_input.exclude_schemas),
|
|
186
220
|
)
|
|
187
221
|
|
|
188
222
|
|
|
@@ -200,7 +234,9 @@ async def test_connection(config: DatabaseConfig, password: str | None) -> bool:
|
|
|
200
234
|
|
|
201
235
|
try:
|
|
202
236
|
connection_string = config.to_connection_string()
|
|
203
|
-
db_conn = DatabaseConnection(
|
|
237
|
+
db_conn = DatabaseConnection(
|
|
238
|
+
connection_string, excluded_schemas=config.exclude_schemas
|
|
239
|
+
)
|
|
204
240
|
await db_conn.execute_query("SELECT 1 as test")
|
|
205
241
|
await db_conn.close()
|
|
206
242
|
return True
|
sqlsaber/cli/commands.py
CHANGED
|
@@ -214,7 +214,9 @@ def query(
|
|
|
214
214
|
|
|
215
215
|
# Create database connection
|
|
216
216
|
try:
|
|
217
|
-
db_conn = DatabaseConnection(
|
|
217
|
+
db_conn = DatabaseConnection(
|
|
218
|
+
connection_string, excluded_schemas=resolved.excluded_schemas
|
|
219
|
+
)
|
|
218
220
|
log.info("db.connection.created", db_type=type(db_conn).__name__)
|
|
219
221
|
except Exception as e:
|
|
220
222
|
log.exception("db.connection.error", error=str(e))
|
|
@@ -229,8 +231,10 @@ def query(
|
|
|
229
231
|
# Single query mode with streaming
|
|
230
232
|
streaming_handler = StreamingQueryHandler(console)
|
|
231
233
|
db_type = sqlsaber_agent.db_type
|
|
234
|
+
model_name = sqlsaber_agent.agent.model.model_name
|
|
232
235
|
console.print(
|
|
233
236
|
f"[primary]Connected to:[/primary] {db_name} ({db_type})\n"
|
|
237
|
+
f"[primary]Model:[/primary] {model_name}\n"
|
|
234
238
|
)
|
|
235
239
|
log.info("query.execute.start", db_name=db_name, db_type=db_type)
|
|
236
240
|
run = await streaming_handler.execute_streaming_query(
|
sqlsaber/cli/database.py
CHANGED
|
@@ -26,6 +26,28 @@ db_app = cyclopts.App(
|
|
|
26
26
|
)
|
|
27
27
|
|
|
28
28
|
|
|
29
|
+
def _normalize_schema_list(raw_schemas: list[str]) -> list[str]:
|
|
30
|
+
"""Deduplicate schemas while preserving order and case."""
|
|
31
|
+
schemas: list[str] = []
|
|
32
|
+
seen: set[str] = set()
|
|
33
|
+
for schema in raw_schemas:
|
|
34
|
+
item = schema.strip()
|
|
35
|
+
if not item:
|
|
36
|
+
continue
|
|
37
|
+
if item in seen:
|
|
38
|
+
continue
|
|
39
|
+
seen.add(item)
|
|
40
|
+
schemas.append(item)
|
|
41
|
+
return schemas
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _parse_schema_list(raw: str | None) -> list[str]:
|
|
45
|
+
"""Parse comma-separated schema list into cleaned list."""
|
|
46
|
+
if not raw:
|
|
47
|
+
return []
|
|
48
|
+
return _normalize_schema_list(raw.split(","))
|
|
49
|
+
|
|
50
|
+
|
|
29
51
|
@db_app.command
|
|
30
52
|
def add(
|
|
31
53
|
name: Annotated[str, cyclopts.Parameter(help="Name for the database connection")],
|
|
@@ -71,6 +93,13 @@ def add(
|
|
|
71
93
|
str | None,
|
|
72
94
|
cyclopts.Parameter(["--ssl-key"], help="SSL client private key file path"),
|
|
73
95
|
] = None,
|
|
96
|
+
exclude_schemas: Annotated[
|
|
97
|
+
str | None,
|
|
98
|
+
cyclopts.Parameter(
|
|
99
|
+
["--exclude-schemas"],
|
|
100
|
+
help="Comma-separated list of schemas to exclude from introspection",
|
|
101
|
+
),
|
|
102
|
+
] = None,
|
|
74
103
|
interactive: Annotated[
|
|
75
104
|
bool,
|
|
76
105
|
cyclopts.Parameter(
|
|
@@ -119,6 +148,7 @@ def add(
|
|
|
119
148
|
ssl_ca = db_input.ssl_ca
|
|
120
149
|
ssl_cert = db_input.ssl_cert
|
|
121
150
|
ssl_key = db_input.ssl_key
|
|
151
|
+
exclude_schema_list = _normalize_schema_list(db_input.exclude_schemas)
|
|
122
152
|
else:
|
|
123
153
|
# Non-interactive mode - use provided values or defaults
|
|
124
154
|
if type == "sqlite":
|
|
@@ -160,6 +190,7 @@ def add(
|
|
|
160
190
|
if questionary.confirm("Enter password?").ask()
|
|
161
191
|
else ""
|
|
162
192
|
)
|
|
193
|
+
exclude_schema_list = _parse_schema_list(exclude_schemas)
|
|
163
194
|
|
|
164
195
|
# Create database config
|
|
165
196
|
# At this point, all required values should be set
|
|
@@ -180,6 +211,7 @@ def add(
|
|
|
180
211
|
ssl_ca=ssl_ca,
|
|
181
212
|
ssl_cert=ssl_cert,
|
|
182
213
|
ssl_key=ssl_key,
|
|
214
|
+
exclude_schemas=exclude_schema_list,
|
|
183
215
|
)
|
|
184
216
|
|
|
185
217
|
try:
|
|
@@ -219,6 +251,7 @@ def list():
|
|
|
219
251
|
table.add_column("Port", style="warning")
|
|
220
252
|
table.add_column("Database", style="info")
|
|
221
253
|
table.add_column("Username", style="info")
|
|
254
|
+
table.add_column("Excluded Schemas", style="muted")
|
|
222
255
|
table.add_column("SSL", style="success")
|
|
223
256
|
table.add_column("Default", style="error")
|
|
224
257
|
|
|
@@ -241,6 +274,7 @@ def list():
|
|
|
241
274
|
str(db.port) if db.port else "",
|
|
242
275
|
db.database,
|
|
243
276
|
db.username,
|
|
277
|
+
", ".join(db.exclude_schemas) if db.exclude_schemas else "",
|
|
244
278
|
ssl_status,
|
|
245
279
|
is_default,
|
|
246
280
|
)
|
|
@@ -249,6 +283,116 @@ def list():
|
|
|
249
283
|
logger.info("db.list.complete", count=len(databases))
|
|
250
284
|
|
|
251
285
|
|
|
286
|
+
@db_app.command
|
|
287
|
+
def exclude(
|
|
288
|
+
name: Annotated[
|
|
289
|
+
str,
|
|
290
|
+
cyclopts.Parameter(help="Name of the database connection to update"),
|
|
291
|
+
],
|
|
292
|
+
set_schemas: Annotated[
|
|
293
|
+
str | None,
|
|
294
|
+
cyclopts.Parameter(
|
|
295
|
+
["--set"],
|
|
296
|
+
help="Replace excluded schemas with this comma-separated list",
|
|
297
|
+
),
|
|
298
|
+
] = None,
|
|
299
|
+
add_schemas: Annotated[
|
|
300
|
+
str | None,
|
|
301
|
+
cyclopts.Parameter(
|
|
302
|
+
["--add"],
|
|
303
|
+
help="Add comma-separated schemas to the existing exclude list",
|
|
304
|
+
),
|
|
305
|
+
] = None,
|
|
306
|
+
remove_schemas: Annotated[
|
|
307
|
+
str | None,
|
|
308
|
+
cyclopts.Parameter(
|
|
309
|
+
["--remove"],
|
|
310
|
+
help="Remove comma-separated schemas from the existing exclude list",
|
|
311
|
+
),
|
|
312
|
+
] = None,
|
|
313
|
+
clear: Annotated[
|
|
314
|
+
bool,
|
|
315
|
+
cyclopts.Parameter(
|
|
316
|
+
["--clear", "--no-clear"],
|
|
317
|
+
help="Clear all excluded schemas",
|
|
318
|
+
),
|
|
319
|
+
] = False,
|
|
320
|
+
):
|
|
321
|
+
"""Update excluded schemas for a database connection."""
|
|
322
|
+
logger.info(
|
|
323
|
+
"db.exclude.start",
|
|
324
|
+
name=name,
|
|
325
|
+
set=bool(set_schemas),
|
|
326
|
+
add=bool(add_schemas),
|
|
327
|
+
remove=bool(remove_schemas),
|
|
328
|
+
clear=clear,
|
|
329
|
+
)
|
|
330
|
+
db_config = config_manager.get_database(name)
|
|
331
|
+
if not db_config:
|
|
332
|
+
console.print(
|
|
333
|
+
f"[bold error]Error: Database connection '{name}' not found[/bold error]"
|
|
334
|
+
)
|
|
335
|
+
logger.error("db.exclude.not_found", name=name)
|
|
336
|
+
sys.exit(1)
|
|
337
|
+
|
|
338
|
+
actions_selected = sum(
|
|
339
|
+
bool(flag)
|
|
340
|
+
for flag in [
|
|
341
|
+
set_schemas is not None,
|
|
342
|
+
add_schemas is not None,
|
|
343
|
+
remove_schemas is not None,
|
|
344
|
+
clear,
|
|
345
|
+
]
|
|
346
|
+
)
|
|
347
|
+
if actions_selected > 1:
|
|
348
|
+
console.print(
|
|
349
|
+
"[bold error]Error: Specify only one of --set, --add, --remove, or --clear[/bold error]"
|
|
350
|
+
)
|
|
351
|
+
logger.error("db.exclude.multiple_actions", name=name)
|
|
352
|
+
sys.exit(1)
|
|
353
|
+
|
|
354
|
+
current = [*(db_config.exclude_schemas or [])]
|
|
355
|
+
|
|
356
|
+
if clear:
|
|
357
|
+
updated = []
|
|
358
|
+
elif set_schemas is not None:
|
|
359
|
+
updated = _parse_schema_list(set_schemas)
|
|
360
|
+
elif add_schemas is not None:
|
|
361
|
+
additions = _parse_schema_list(add_schemas)
|
|
362
|
+
updated = [*current]
|
|
363
|
+
current_set = set(current)
|
|
364
|
+
for schema in additions:
|
|
365
|
+
if schema not in current_set:
|
|
366
|
+
updated.append(schema)
|
|
367
|
+
current_set.add(schema)
|
|
368
|
+
elif remove_schemas is not None:
|
|
369
|
+
removals = set(_parse_schema_list(remove_schemas))
|
|
370
|
+
updated = [schema for schema in current if schema not in removals]
|
|
371
|
+
else:
|
|
372
|
+
console.print(
|
|
373
|
+
"[info]Update excluded schemas for "
|
|
374
|
+
f"[primary]{name}[/primary] (leave blank to clear)[/info]"
|
|
375
|
+
)
|
|
376
|
+
default_value = ", ".join(current)
|
|
377
|
+
response = questionary.text(
|
|
378
|
+
"Schemas to exclude (comma separated):", default=default_value
|
|
379
|
+
).ask()
|
|
380
|
+
if response is None:
|
|
381
|
+
console.print("[warning]Operation cancelled[/warning]")
|
|
382
|
+
logger.info("db.exclude.cancelled", name=name)
|
|
383
|
+
return
|
|
384
|
+
updated = _parse_schema_list(response)
|
|
385
|
+
|
|
386
|
+
db_config.exclude_schemas = _normalize_schema_list(updated)
|
|
387
|
+
config_manager.update_database(db_config)
|
|
388
|
+
|
|
389
|
+
console.print(
|
|
390
|
+
f"[success]Updated excluded schemas for '{name}':[/success] "
|
|
391
|
+
f"{', '.join(db_config.exclude_schemas) if db_config.exclude_schemas else '(none)'}"
|
|
392
|
+
)
|
|
393
|
+
logger.info("db.exclude.success", name=name, count=len(db_config.exclude_schemas))
|
|
394
|
+
|
|
395
|
+
|
|
252
396
|
@db_app.command
|
|
253
397
|
def remove(
|
|
254
398
|
name: Annotated[
|
|
@@ -259,7 +403,7 @@ def remove(
|
|
|
259
403
|
logger.info("db.remove.start", name=name)
|
|
260
404
|
if not config_manager.get_database(name):
|
|
261
405
|
console.print(
|
|
262
|
-
f"[bold error]Error:
|
|
406
|
+
f"[bold error]Error: Database connection '{name}' not found[/bold error]"
|
|
263
407
|
)
|
|
264
408
|
logger.error("db.remove.not_found", name=name)
|
|
265
409
|
sys.exit(1)
|
|
@@ -269,17 +413,17 @@ def remove(
|
|
|
269
413
|
).ask():
|
|
270
414
|
if config_manager.remove_database(name):
|
|
271
415
|
console.print(
|
|
272
|
-
f"[
|
|
416
|
+
f"[success]Successfully removed database connection '{name}'[/success]"
|
|
273
417
|
)
|
|
274
418
|
logger.info("db.remove.success", name=name)
|
|
275
419
|
else:
|
|
276
420
|
console.print(
|
|
277
|
-
f"[bold error]Error:
|
|
421
|
+
f"[bold error]Error: Failed to remove database connection '{name}'[/bold error]"
|
|
278
422
|
)
|
|
279
423
|
logger.error("db.remove.failed", name=name)
|
|
280
424
|
sys.exit(1)
|
|
281
425
|
else:
|
|
282
|
-
console.print("Operation cancelled")
|
|
426
|
+
console.print("[warning]Operation cancelled[/warning]")
|
|
283
427
|
logger.info("db.remove.cancelled", name=name)
|
|
284
428
|
|
|
285
429
|
|
|
@@ -294,17 +438,19 @@ def set_default(
|
|
|
294
438
|
logger.info("db.default.start", name=name)
|
|
295
439
|
if not config_manager.get_database(name):
|
|
296
440
|
console.print(
|
|
297
|
-
f"[bold error]Error:
|
|
441
|
+
f"[bold error]Error: Database connection '{name}' not found[/bold error]"
|
|
298
442
|
)
|
|
299
443
|
logger.error("db.default.not_found", name=name)
|
|
300
444
|
sys.exit(1)
|
|
301
445
|
|
|
302
446
|
if config_manager.set_default_database(name):
|
|
303
|
-
console.print(
|
|
447
|
+
console.print(
|
|
448
|
+
f"[success]Successfully set '{name}' as default database[/success]"
|
|
449
|
+
)
|
|
304
450
|
logger.info("db.default.success", name=name)
|
|
305
451
|
else:
|
|
306
452
|
console.print(
|
|
307
|
-
f"[bold error]Error:
|
|
453
|
+
f"[bold error]Error: Failed to set '{name}' as default[/bold error]"
|
|
308
454
|
)
|
|
309
455
|
logger.error("db.default.failed", name=name)
|
|
310
456
|
sys.exit(1)
|
|
@@ -330,7 +476,7 @@ def test(
|
|
|
330
476
|
db_config = config_manager.get_database(name)
|
|
331
477
|
if not db_config:
|
|
332
478
|
console.print(
|
|
333
|
-
f"[bold error]Error:
|
|
479
|
+
f"[bold error]Error: Database connection '{name}' not found[/bold error]"
|
|
334
480
|
)
|
|
335
481
|
logger.error("db.test.not_found", name=name)
|
|
336
482
|
sys.exit(1)
|
|
@@ -338,7 +484,7 @@ def test(
|
|
|
338
484
|
db_config = config_manager.get_default_database()
|
|
339
485
|
if not db_config:
|
|
340
486
|
console.print(
|
|
341
|
-
"[bold error]Error:
|
|
487
|
+
"[bold error]Error: No default database configured[/bold error]"
|
|
342
488
|
)
|
|
343
489
|
console.print(
|
|
344
490
|
"Use 'sqlsaber db add <name>' to add a database connection"
|
|
@@ -350,14 +496,16 @@ def test(
|
|
|
350
496
|
|
|
351
497
|
try:
|
|
352
498
|
connection_string = db_config.to_connection_string()
|
|
353
|
-
db_conn = DatabaseConnection(
|
|
499
|
+
db_conn = DatabaseConnection(
|
|
500
|
+
connection_string, excluded_schemas=db_config.exclude_schemas
|
|
501
|
+
)
|
|
354
502
|
|
|
355
503
|
# Try to connect and run a simple query
|
|
356
504
|
await db_conn.execute_query("SELECT 1 as test")
|
|
357
505
|
await db_conn.close()
|
|
358
506
|
|
|
359
507
|
console.print(
|
|
360
|
-
f"[
|
|
508
|
+
f"[success]✓ Connection to '{db_config.name}' successful[/success]"
|
|
361
509
|
)
|
|
362
510
|
logger.info("db.test.success", name=db_config.name)
|
|
363
511
|
|
|
@@ -369,7 +517,7 @@ def test(
|
|
|
369
517
|
),
|
|
370
518
|
error=str(e),
|
|
371
519
|
)
|
|
372
|
-
console.print(f"[bold error]✗ Connection failed:[/bold error]
|
|
520
|
+
console.print(f"[bold error]✗ Connection failed: {e}[/bold error]")
|
|
373
521
|
sys.exit(1)
|
|
374
522
|
|
|
375
523
|
asyncio.run(test_connection())
|
sqlsaber/cli/display.py
CHANGED
|
@@ -406,9 +406,17 @@ class DisplayManager:
|
|
|
406
406
|
for table_name, table_info in data.items():
|
|
407
407
|
self.console.print(f"\n[heading]Table: {table_name}[/heading]")
|
|
408
408
|
|
|
409
|
+
table_comment = table_info.get("comment")
|
|
410
|
+
if table_comment:
|
|
411
|
+
self.console.print(f"[muted]Comment: {table_comment}[/muted]")
|
|
412
|
+
|
|
409
413
|
# Show columns
|
|
410
414
|
table_columns = table_info.get("columns", {})
|
|
411
415
|
if table_columns:
|
|
416
|
+
include_column_comments = any(
|
|
417
|
+
col_info.get("comment") for col_info in table_columns.values()
|
|
418
|
+
)
|
|
419
|
+
|
|
412
420
|
# Create a table for columns
|
|
413
421
|
columns = [
|
|
414
422
|
{"name": "Column Name", "style": "column.name"},
|
|
@@ -416,6 +424,8 @@ class DisplayManager:
|
|
|
416
424
|
{"name": "Nullable", "style": "info"},
|
|
417
425
|
{"name": "Default", "style": "muted"},
|
|
418
426
|
]
|
|
427
|
+
if include_column_comments:
|
|
428
|
+
columns.append({"name": "Comment", "style": "muted"})
|
|
419
429
|
col_table = self._create_table(columns, title="Columns")
|
|
420
430
|
|
|
421
431
|
for col_name, col_info in table_columns.items():
|
|
@@ -425,9 +435,15 @@ class DisplayManager:
|
|
|
425
435
|
if col_info.get("default")
|
|
426
436
|
else ""
|
|
427
437
|
)
|
|
428
|
-
|
|
429
|
-
col_name,
|
|
430
|
-
|
|
438
|
+
row = [
|
|
439
|
+
col_name,
|
|
440
|
+
col_info.get("type", ""),
|
|
441
|
+
nullable,
|
|
442
|
+
default,
|
|
443
|
+
]
|
|
444
|
+
if include_column_comments:
|
|
445
|
+
row.append(col_info.get("comment") or "")
|
|
446
|
+
col_table.add_row(*row)
|
|
431
447
|
|
|
432
448
|
self.console.print(col_table)
|
|
433
449
|
|
sqlsaber/cli/interactive.py
CHANGED
|
@@ -20,6 +20,7 @@ from sqlsaber.cli.completers import (
|
|
|
20
20
|
)
|
|
21
21
|
from sqlsaber.cli.display import DisplayManager
|
|
22
22
|
from sqlsaber.cli.streaming import StreamingQueryHandler
|
|
23
|
+
from sqlsaber.config.logging import get_logger
|
|
23
24
|
from sqlsaber.database import (
|
|
24
25
|
CSVConnection,
|
|
25
26
|
DuckDBConnection,
|
|
@@ -30,7 +31,6 @@ from sqlsaber.database import (
|
|
|
30
31
|
from sqlsaber.database.schema import SchemaManager
|
|
31
32
|
from sqlsaber.theme.manager import get_theme_manager
|
|
32
33
|
from sqlsaber.threads import ThreadStorage
|
|
33
|
-
from sqlsaber.config.logging import get_logger
|
|
34
34
|
|
|
35
35
|
if TYPE_CHECKING:
|
|
36
36
|
from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
|
|
@@ -135,8 +135,10 @@ class InteractiveSession:
|
|
|
135
135
|
)
|
|
136
136
|
|
|
137
137
|
db_name = self.database_name or "Unknown"
|
|
138
|
+
model_name = self.sqlsaber_agent.agent.model.model_name
|
|
138
139
|
self.console.print(
|
|
139
140
|
f"[heading]\nConnected to {db_name} ({self._db_type_name()})[/heading]\n"
|
|
141
|
+
f"[heading]Model: {model_name}[/heading]\n"
|
|
140
142
|
)
|
|
141
143
|
|
|
142
144
|
if self._thread_id:
|
|
@@ -309,6 +311,8 @@ class InteractiveSession:
|
|
|
309
311
|
style=self.tm.pt_style(),
|
|
310
312
|
)
|
|
311
313
|
|
|
314
|
+
user_query = user_query.strip()
|
|
315
|
+
|
|
312
316
|
if not user_query:
|
|
313
317
|
continue
|
|
314
318
|
|
|
@@ -325,7 +329,7 @@ class InteractiveSession:
|
|
|
325
329
|
|
|
326
330
|
# Handle memory addition
|
|
327
331
|
if user_query.strip().startswith("#"):
|
|
328
|
-
await self._handle_memory(user_query
|
|
332
|
+
await self._handle_memory(user_query[1:].strip())
|
|
329
333
|
continue
|
|
330
334
|
|
|
331
335
|
# Execute query with cancellation support
|
sqlsaber/cli/threads.py
CHANGED
|
@@ -229,7 +229,7 @@ def list_threads(
|
|
|
229
229
|
logger.info("threads.cli.list.empty")
|
|
230
230
|
return
|
|
231
231
|
table = Table(title="Threads")
|
|
232
|
-
table.add_column("ID", style=tm.style("info"))
|
|
232
|
+
table.add_column("ID", style=tm.style("info"), no_wrap=True, min_width=36)
|
|
233
233
|
table.add_column("Database", style=tm.style("accent"))
|
|
234
234
|
table.add_column("Title", style=tm.style("success"))
|
|
235
235
|
table.add_column("Last Activity", style=tm.style("muted"))
|
|
@@ -318,7 +318,9 @@ def resume(
|
|
|
318
318
|
)
|
|
319
319
|
return
|
|
320
320
|
|
|
321
|
-
db_conn = DatabaseConnection(
|
|
321
|
+
db_conn = DatabaseConnection(
|
|
322
|
+
connection_string, excluded_schemas=resolved.excluded_schemas
|
|
323
|
+
)
|
|
322
324
|
try:
|
|
323
325
|
sqlsaber_agent = SQLSaberAgent(db_conn, db_name)
|
|
324
326
|
history = await store.get_thread_messages(thread_id)
|
sqlsaber/config/database.py
CHANGED
|
@@ -4,7 +4,7 @@ import json
|
|
|
4
4
|
import os
|
|
5
5
|
import platform
|
|
6
6
|
import stat
|
|
7
|
-
from dataclasses import dataclass
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
8
|
from pathlib import Path
|
|
9
9
|
from typing import Any
|
|
10
10
|
from urllib.parse import quote_plus
|
|
@@ -29,6 +29,7 @@ class DatabaseConfig:
|
|
|
29
29
|
ssl_cert: str | None = None
|
|
30
30
|
ssl_key: str | None = None
|
|
31
31
|
schema: str | None = None
|
|
32
|
+
exclude_schemas: list[str] = field(default_factory=list)
|
|
32
33
|
|
|
33
34
|
def to_connection_string(self) -> str:
|
|
34
35
|
"""Convert config to database connection string."""
|
|
@@ -149,6 +150,7 @@ class DatabaseConfig:
|
|
|
149
150
|
"ssl_cert": self.ssl_cert,
|
|
150
151
|
"ssl_key": self.ssl_key,
|
|
151
152
|
"schema": self.schema,
|
|
153
|
+
"exclude_schemas": self.exclude_schemas,
|
|
152
154
|
}
|
|
153
155
|
|
|
154
156
|
@classmethod
|
|
@@ -166,6 +168,7 @@ class DatabaseConfig:
|
|
|
166
168
|
ssl_cert=data.get("ssl_cert"),
|
|
167
169
|
ssl_key=data.get("ssl_key"),
|
|
168
170
|
schema=data.get("schema"),
|
|
171
|
+
exclude_schemas=list(data.get("exclude_schemas", [])),
|
|
169
172
|
)
|
|
170
173
|
|
|
171
174
|
|
|
@@ -246,6 +249,16 @@ class DatabaseConfigManager:
|
|
|
246
249
|
|
|
247
250
|
self._save_config(config)
|
|
248
251
|
|
|
252
|
+
def update_database(self, db_config: DatabaseConfig) -> None:
|
|
253
|
+
"""Update an existing database configuration."""
|
|
254
|
+
config = self._load_config()
|
|
255
|
+
|
|
256
|
+
if db_config.name not in config["connections"]:
|
|
257
|
+
raise ValueError(f"Database '{db_config.name}' does not exist")
|
|
258
|
+
|
|
259
|
+
config["connections"][db_config.name] = db_config.to_dict()
|
|
260
|
+
self._save_config(config)
|
|
261
|
+
|
|
249
262
|
def get_database(self, name: str) -> DatabaseConfig | None:
|
|
250
263
|
"""Get a database configuration by name."""
|
|
251
264
|
config = self._load_config()
|
sqlsaber/database/__init__.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""Database module for SQLSaber."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
|
|
3
5
|
from .base import (
|
|
4
6
|
DEFAULT_QUERY_TIMEOUT,
|
|
5
7
|
BaseDatabaseConnection,
|
|
@@ -18,23 +20,28 @@ from .schema import SchemaManager
|
|
|
18
20
|
from .sqlite import SQLiteConnection, SQLiteSchemaIntrospector
|
|
19
21
|
|
|
20
22
|
|
|
21
|
-
def DatabaseConnection(
|
|
23
|
+
def DatabaseConnection(
|
|
24
|
+
connection_string: str, *, excluded_schemas: Iterable[str] | None = None
|
|
25
|
+
) -> BaseDatabaseConnection:
|
|
22
26
|
"""Factory function to create appropriate database connection based on connection string."""
|
|
23
27
|
if connection_string.startswith("postgresql://"):
|
|
24
|
-
|
|
28
|
+
conn = PostgreSQLConnection(connection_string)
|
|
25
29
|
elif connection_string.startswith("mysql://"):
|
|
26
|
-
|
|
30
|
+
conn = MySQLConnection(connection_string)
|
|
27
31
|
elif connection_string.startswith("sqlite:///"):
|
|
28
|
-
|
|
32
|
+
conn = SQLiteConnection(connection_string)
|
|
29
33
|
elif connection_string.startswith("duckdb://"):
|
|
30
|
-
|
|
34
|
+
conn = DuckDBConnection(connection_string)
|
|
31
35
|
elif connection_string.startswith("csv:///"):
|
|
32
|
-
|
|
36
|
+
conn = CSVConnection(connection_string)
|
|
33
37
|
else:
|
|
34
38
|
raise ValueError(
|
|
35
39
|
f"Unsupported database type in connection string: {connection_string}"
|
|
36
40
|
)
|
|
37
41
|
|
|
42
|
+
conn.set_excluded_schemas(excluded_schemas)
|
|
43
|
+
return conn
|
|
44
|
+
|
|
38
45
|
|
|
39
46
|
__all__ = [
|
|
40
47
|
# Base classes and types
|