sqlsaber 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

Files changed (38) hide show
  1. sqlsaber/agents/__init__.py +2 -2
  2. sqlsaber/agents/base.py +1 -1
  3. sqlsaber/agents/mcp.py +1 -1
  4. sqlsaber/agents/pydantic_ai_agent.py +207 -135
  5. sqlsaber/application/__init__.py +1 -0
  6. sqlsaber/application/auth_setup.py +164 -0
  7. sqlsaber/application/db_setup.py +223 -0
  8. sqlsaber/application/model_selection.py +98 -0
  9. sqlsaber/application/prompts.py +115 -0
  10. sqlsaber/cli/auth.py +22 -50
  11. sqlsaber/cli/commands.py +22 -28
  12. sqlsaber/cli/completers.py +2 -0
  13. sqlsaber/cli/database.py +25 -86
  14. sqlsaber/cli/display.py +29 -9
  15. sqlsaber/cli/interactive.py +150 -127
  16. sqlsaber/cli/models.py +18 -28
  17. sqlsaber/cli/onboarding.py +325 -0
  18. sqlsaber/cli/streaming.py +15 -17
  19. sqlsaber/cli/threads.py +10 -6
  20. sqlsaber/config/api_keys.py +2 -2
  21. sqlsaber/config/settings.py +25 -2
  22. sqlsaber/database/__init__.py +55 -1
  23. sqlsaber/database/base.py +124 -0
  24. sqlsaber/database/csv.py +133 -0
  25. sqlsaber/database/duckdb.py +313 -0
  26. sqlsaber/database/mysql.py +345 -0
  27. sqlsaber/database/postgresql.py +328 -0
  28. sqlsaber/database/schema.py +66 -963
  29. sqlsaber/database/sqlite.py +258 -0
  30. sqlsaber/mcp/mcp.py +1 -1
  31. sqlsaber/tools/sql_tools.py +1 -1
  32. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/METADATA +43 -9
  33. sqlsaber-0.27.0.dist-info/RECORD +58 -0
  34. sqlsaber/database/connection.py +0 -535
  35. sqlsaber-0.25.0.dist-info/RECORD +0 -47
  36. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/WHEEL +0 -0
  37. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/entry_points.txt +0 -0
  38. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,124 @@
1
+ """Base classes and type definitions for database connections and schema introspection."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, TypedDict
5
+
6
+ # Default query timeout to prevent runaway queries
7
+ DEFAULT_QUERY_TIMEOUT = 30.0 # seconds
8
+
9
+
10
class QueryTimeoutError(RuntimeError):
    """Raised when a database query runs longer than its allotted timeout."""

    def __init__(self, seconds: float):
        super().__init__(f"Query exceeded timeout of {seconds}s")
        # Preserve the configured limit so callers can inspect or report it.
        self.timeout = seconds
16
+
17
+
18
class ColumnInfo(TypedDict):
    """Type definition for column information.

    Fields mirror information_schema.columns output produced by the
    schema introspectors (data_type, is_nullable, column_default,
    character_maximum_length, numeric_precision, numeric_scale).
    """

    # SQL type name as reported by the backend.
    data_type: str
    # True when the column accepts NULL.
    nullable: bool
    # Textual default expression, or None when the column has no default.
    default: str | None
    # character_maximum_length for string types; None otherwise.
    max_length: int | None
    # numeric_precision for numeric types; None otherwise.
    precision: int | None
    # numeric_scale for numeric types; None otherwise.
    scale: int | None
27
+
28
+
29
class ForeignKeyInfo(TypedDict):
    """Type definition for foreign key information."""

    # Referencing column on the local table.
    column: str
    references: dict[str, str]  # {"table": "schema.table", "column": "column_name"}
34
+
35
+
36
class IndexInfo(TypedDict):
    """Type definition for index information."""

    # Index name as reported by the backend.
    name: str
    columns: list[str]  # ordered
    # True for unique indexes (including those backing unique constraints).
    unique: bool
    type: str | None  # btree, gin, FULLTEXT, etc. None if unknown
43
+
44
+
45
class SchemaInfo(TypedDict):
    """Type definition for schema information.

    Aggregated per-table description assembled from the introspector
    queries: columns keyed by column name, plus key and index metadata.
    """

    # Schema (namespace) the table lives in.
    schema: str
    # Table name without the schema qualifier.
    name: str
    # Table type, e.g. the information_schema table_type value.
    type: str
    # Column metadata keyed by column name.
    columns: dict[str, ColumnInfo]
    # Ordered list of primary key column names.
    primary_keys: list[str]
    foreign_keys: list[ForeignKeyInfo]
    indexes: list[IndexInfo]
55
+
56
+
57
+ class BaseDatabaseConnection(ABC):
58
+ """Abstract base class for database connections."""
59
+
60
+ def __init__(self, connection_string: str):
61
+ self.connection_string = connection_string
62
+ self._pool = None
63
+
64
+ @abstractmethod
65
+ async def get_pool(self):
66
+ """Get or create connection pool."""
67
+ pass
68
+
69
+ @abstractmethod
70
+ async def close(self):
71
+ """Close the connection pool."""
72
+ pass
73
+
74
+ @abstractmethod
75
+ async def execute_query(
76
+ self, query: str, *args, timeout: float | None = None
77
+ ) -> list[dict[str, Any]]:
78
+ """Execute a query and return results as list of dicts.
79
+
80
+ All queries run in a transaction that is rolled back at the end,
81
+ ensuring no changes are persisted to the database.
82
+
83
+ Args:
84
+ query: SQL query to execute
85
+ *args: Query parameters
86
+ timeout: Query timeout in seconds (overrides default_timeout)
87
+ """
88
+ pass
89
+
90
+
91
+ class BaseSchemaIntrospector(ABC):
92
+ """Abstract base class for database-specific schema introspection."""
93
+
94
+ @abstractmethod
95
+ async def get_tables_info(
96
+ self, connection, table_pattern: str | None = None
97
+ ) -> dict[str, Any]:
98
+ """Get tables information for the specific database type."""
99
+ pass
100
+
101
+ @abstractmethod
102
+ async def get_columns_info(self, connection, tables: list) -> list:
103
+ """Get columns information for the specific database type."""
104
+ pass
105
+
106
+ @abstractmethod
107
+ async def get_foreign_keys_info(self, connection, tables: list) -> list:
108
+ """Get foreign keys information for the specific database type."""
109
+ pass
110
+
111
+ @abstractmethod
112
+ async def get_primary_keys_info(self, connection, tables: list) -> list:
113
+ """Get primary keys information for the specific database type."""
114
+ pass
115
+
116
+ @abstractmethod
117
+ async def get_indexes_info(self, connection, tables: list) -> list:
118
+ """Get indexes information for the specific database type."""
119
+ pass
120
+
121
+ @abstractmethod
122
+ async def list_tables_info(self, connection) -> list[dict[str, Any]]:
123
+ """Get list of tables with basic information."""
124
+ pass
@@ -0,0 +1,133 @@
1
+ """CSV database connection using DuckDB backend."""
2
+
3
+ import asyncio
4
+ from pathlib import Path
5
+ from typing import Any
6
+ from urllib.parse import parse_qs, urlparse
7
+
8
+ import duckdb
9
+
10
+ from .base import DEFAULT_QUERY_TIMEOUT, BaseDatabaseConnection, QueryTimeoutError
11
+ from .duckdb import DuckDBSchemaIntrospector
12
+
13
+
14
def _execute_duckdb_transaction(
    conn: duckdb.DuckDBPyConnection, query: str, args: tuple[Any, ...]
) -> list[dict[str, Any]]:
    """Run a DuckDB query inside a transaction and return list of dicts.

    The transaction is always rolled back — on success and on failure —
    so the underlying data is never modified.
    """
    conn.execute("BEGIN TRANSACTION")
    try:
        if args:
            conn.execute(query, args)
        else:
            conn.execute(query)

        description = conn.description
        result: list[dict[str, Any]] = []
        if description is not None:
            names = [entry[0] for entry in description]
            result = [dict(zip(names, values)) for values in conn.fetchall()]

        conn.execute("ROLLBACK")
        return result
    except Exception:
        # Roll back before re-raising so the connection is left clean.
        conn.execute("ROLLBACK")
        raise
37
+
38
+
39
class CSVConnection(BaseDatabaseConnection):
    """CSV file connection using DuckDB per query.

    Each query spins up a fresh in-memory DuckDB instance, exposes the
    CSV file as a view, runs the query, and tears everything down again.
    """

    def __init__(self, connection_string: str):
        super().__init__(connection_string)

        # Strip the scheme, then drop any query string to get the file path.
        without_scheme = connection_string.replace("csv:///", "", 1)
        self.csv_path = without_scheme.split("?", 1)[0]

        # Parsing defaults; URL query parameters below may override them.
        self.delimiter = ","
        self.encoding = "utf-8"
        self.has_header = True

        parsed = urlparse(connection_string)
        if parsed.query:
            options = parse_qs(parsed.query)
            self.delimiter = options.get("delimiter", [self.delimiter])[0]
            self.encoding = options.get("encoding", [self.encoding])[0]
            self.has_header = options.get("header", ["true"])[0].lower() == "true"

        # View name derived from the file name ("data.csv" -> "data").
        self.table_name = Path(self.csv_path).stem or "csv_table"

    async def get_pool(self):
        """CSV connections do not maintain a pool."""
        return None

    async def close(self):
        """No persistent resources to close for CSV connections."""

    def _quote_identifier(self, identifier: str) -> str:
        # Double-quote an identifier, doubling embedded quotes.
        return '"' + identifier.replace('"', '""') + '"'

    def _quote_literal(self, value: str) -> str:
        # Single-quote a string literal, doubling embedded quotes.
        return "'" + value.replace("'", "''") + "'"

    def _normalized_encoding(self) -> str | None:
        """Return the encoding in DuckDB's form, or None for the default."""
        name = (self.encoding or "").strip()
        if not name or name.lower() == "utf-8":
            return None
        return name.replace("-", "").replace("_", "").upper()

    def _create_view(self, conn: duckdb.DuckDBPyConnection) -> None:
        """Expose the CSV file as a named view on *conn*."""
        options = ["HEADER=TRUE" if self.has_header else "HEADER=FALSE"]

        if self.delimiter:
            options.append(f"DELIM={self._quote_literal(self.delimiter)}")

        encoding = self._normalized_encoding()
        if encoding:
            options.append(f"ENCODING={self._quote_literal(encoding)}")

        options_sql = ", " + ", ".join(options) if options else ""

        relation = f"read_csv_auto({self._quote_literal(self.csv_path)}{options_sql})"
        conn.execute(
            f"CREATE VIEW {self._quote_identifier(self.table_name)} AS "
            f"SELECT * FROM {relation}"
        )

    async def execute_query(
        self, query: str, *args, timeout: float | None = None
    ) -> list[dict[str, Any]]:
        """Execute *query* against the CSV view; always rolled back."""
        effective_timeout = timeout or DEFAULT_QUERY_TIMEOUT
        bound_args = tuple(args) if args else tuple()

        def _query_worker() -> list[dict[str, Any]]:
            # Fresh in-memory database per query keeps CSV access stateless.
            conn = duckdb.connect(":memory:")
            try:
                self._create_view(conn)
                return _execute_duckdb_transaction(conn, query, bound_args)
            finally:
                conn.close()

        try:
            return await asyncio.wait_for(
                asyncio.to_thread(_query_worker), timeout=effective_timeout
            )
        except asyncio.TimeoutError as exc:
            raise QueryTimeoutError(effective_timeout or 0) from exc
128
+
129
+
130
class CSVSchemaIntrospector(DuckDBSchemaIntrospector):
    """CSV-specific schema introspection using DuckDB backend."""
@@ -0,0 +1,313 @@
1
+ """DuckDB database connection and schema introspection."""
2
+
3
+ import asyncio
4
+ from typing import Any
5
+
6
+ import duckdb
7
+
8
+ from .base import (
9
+ DEFAULT_QUERY_TIMEOUT,
10
+ BaseDatabaseConnection,
11
+ BaseSchemaIntrospector,
12
+ QueryTimeoutError,
13
+ )
14
+
15
+
16
def _execute_duckdb_transaction(
    conn: duckdb.DuckDBPyConnection, query: str, args: tuple[Any, ...]
) -> list[dict[str, Any]]:
    """Run a DuckDB query inside a transaction and return list of dicts.

    The transaction is rolled back on both success and failure, so no
    statement executed here can persist changes to the database.
    """
    conn.execute("BEGIN TRANSACTION")
    try:
        conn.execute(query, args) if args else None
        if not args:
            conn.execute(query)

        rows: list[dict[str, Any]] = []
        if conn.description is not None:
            column_names = [entry[0] for entry in conn.description]
            rows = [dict(zip(column_names, record)) for record in conn.fetchall()]

        conn.execute("ROLLBACK")
        return rows
    except Exception:
        # Leave the connection clean before propagating the error.
        conn.execute("ROLLBACK")
        raise
39
+
40
+
41
class DuckDBConnection(BaseDatabaseConnection):
    """DuckDB database connection using duckdb Python API.

    No pool is kept: a fresh connection is opened per query and closed
    immediately afterwards.
    """

    def __init__(self, connection_string: str):
        super().__init__(connection_string)
        # Accept "duckdb:///path", "duckdb://path", or a bare filesystem path.
        # The three-slash form must be tested first since it contains the
        # two-slash prefix.
        for prefix in ("duckdb:///", "duckdb://"):
            if connection_string.startswith(prefix):
                path = connection_string.replace(prefix, "", 1)
                break
        else:
            path = connection_string

        # An empty path means an in-memory database.
        self.database_path = path or ":memory:"

    async def get_pool(self):
        """DuckDB creates connections per query; return the database path."""
        return self.database_path

    async def close(self):
        """Connections are per-query, so there is no persistent pool to close."""

    async def execute_query(
        self, query: str, *args, timeout: float | None = None
    ) -> list[dict[str, Any]]:
        """Execute a query and return results as list of dicts.

        All queries run in a transaction that is rolled back at the end,
        ensuring no changes are persisted to the database.
        """
        effective_timeout = timeout or DEFAULT_QUERY_TIMEOUT
        bound_args = tuple(args) if args else tuple()

        def _query_worker() -> list[dict[str, Any]]:
            conn = duckdb.connect(self.database_path)
            try:
                return _execute_duckdb_transaction(conn, query, bound_args)
            finally:
                conn.close()

        try:
            # Run the blocking duckdb call off the event loop, bounded by
            # the effective timeout.
            return await asyncio.wait_for(
                asyncio.to_thread(_query_worker), timeout=effective_timeout
            )
        except asyncio.TimeoutError as exc:
            raise QueryTimeoutError(effective_timeout or 0) from exc
88
+
89
+
90
class DuckDBSchemaIntrospector(BaseSchemaIntrospector):
    """DuckDB-specific schema introspection.

    Works against both DuckDBConnection (opens a connection per query on
    a worker thread) and CSVConnection (delegates to its execute_query).
    """

    async def _execute_query(
        self,
        connection,
        query: str,
        params: tuple[Any, ...] = (),
    ) -> list[dict[str, Any]]:
        """Run a DuckDB query on a thread and return list of dictionaries."""

        params_tuple = tuple(params)

        def fetch_rows(conn: duckdb.DuckDBPyConnection) -> list[dict[str, Any]]:
            cursor = conn.execute(query, params_tuple)
            # No description means the statement produced no result set.
            if cursor.description is None:
                return []

            columns = [col[0] for col in cursor.description]
            rows = conn.fetchall()
            return [dict(zip(columns, row)) for row in rows]

        # Handle CSV connections differently: they are detected by duck
        # typing (presence of execute_query and csv_path) and already run
        # their queries on a thread with their own DuckDB instance.
        if hasattr(connection, "execute_query") and hasattr(connection, "csv_path"):
            return await connection.execute_query(query, *params_tuple)

        def run_query() -> list[dict[str, Any]]:
            # Fresh connection per query; always closed, even on error.
            conn = duckdb.connect(connection.database_path)
            try:
                return fetch_rows(conn)
            finally:
                conn.close()

        return await asyncio.to_thread(run_query)

    async def get_tables_info(
        self, connection, table_pattern: str | None = None
    ) -> list[dict[str, Any]]:
        """Get tables information for DuckDB.

        table_pattern supports LIKE wildcards; "schema.table" patterns
        match schema and table separately.
        """
        # Exclude DuckDB's system catalogs from the listing.
        where_conditions = [
            "table_schema NOT IN ('information_schema', 'pg_catalog', 'duckdb_catalog')"
        ]
        params: list[Any] = []

        if table_pattern:
            if "." in table_pattern:
                # Qualified pattern: split once so the table part may still
                # contain dots.
                schema_pattern, table_name_pattern = table_pattern.split(".", 1)
                where_conditions.append("(table_schema LIKE ? AND table_name LIKE ?)")
                params.extend([schema_pattern, table_name_pattern])
            else:
                # Unqualified pattern: match the bare name or the
                # schema-qualified name.
                where_conditions.append(
                    "(table_name LIKE ? OR table_schema || '.' || table_name LIKE ?)"
                )
                params.extend([table_pattern, table_pattern])

        query = f"""
            SELECT
                table_schema,
                table_name,
                table_type
            FROM information_schema.tables
            WHERE {" AND ".join(where_conditions)}
            ORDER BY table_schema, table_name;
        """

        return await self._execute_query(connection, query, tuple(params))

    async def get_columns_info(self, connection, tables: list) -> list[dict[str, Any]]:
        """Get columns information for DuckDB.

        *tables* is a list of dicts with table_schema/table_name keys, as
        returned by get_tables_info.
        """
        if not tables:
            return []

        # One (schema, name) equality filter per requested table.
        table_filters = []
        for table in tables:
            table_filters.append("(table_schema = ? AND table_name = ?)")

        params: list[Any] = []
        for table in tables:
            params.extend([table["table_schema"], table["table_name"]])

        query = f"""
            SELECT
                table_schema,
                table_name,
                column_name,
                data_type,
                is_nullable,
                column_default,
                character_maximum_length,
                numeric_precision,
                numeric_scale
            FROM information_schema.columns
            WHERE {" OR ".join(table_filters)}
            ORDER BY table_schema, table_name, ordinal_position;
        """

        return await self._execute_query(connection, query, tuple(params))

    async def get_foreign_keys_info(
        self, connection, tables: list
    ) -> list[dict[str, Any]]:
        """Get foreign keys information for DuckDB."""
        if not tables:
            return []

        table_filters = []
        params: list[Any] = []
        for table in tables:
            table_filters.append("(kcu.table_schema = ? AND kcu.table_name = ?)")
            params.extend([table["table_schema"], table["table_name"]])

        # Join referential_constraints to key_column_usage twice: once for
        # the referencing columns (kcu) and once for the referenced unique
        # constraint's columns (ccu), pairing them by position.
        query = f"""
            SELECT
                kcu.table_schema,
                kcu.table_name,
                kcu.column_name,
                ccu.table_schema AS foreign_table_schema,
                ccu.table_name AS foreign_table_name,
                ccu.column_name AS foreign_column_name
            FROM information_schema.referential_constraints AS rc
            JOIN information_schema.key_column_usage AS kcu
                ON rc.constraint_schema = kcu.constraint_schema
                AND rc.constraint_name = kcu.constraint_name
            JOIN information_schema.key_column_usage AS ccu
                ON rc.unique_constraint_schema = ccu.constraint_schema
                AND rc.unique_constraint_name = ccu.constraint_name
                AND ccu.ordinal_position = kcu.position_in_unique_constraint
            WHERE {" OR ".join(table_filters)}
            ORDER BY kcu.table_schema, kcu.table_name, kcu.ordinal_position;
        """

        return await self._execute_query(connection, query, tuple(params))

    async def get_primary_keys_info(
        self, connection, tables: list
    ) -> list[dict[str, Any]]:
        """Get primary keys information for DuckDB."""
        if not tables:
            return []

        table_filters = []
        params: list[Any] = []
        for table in tables:
            table_filters.append("(tc.table_schema = ? AND tc.table_name = ?)")
            params.extend([table["table_schema"], table["table_name"]])

        # Ordered by ordinal_position so composite keys come back in
        # declaration order.
        query = f"""
            SELECT
                tc.table_schema,
                tc.table_name,
                kcu.column_name
            FROM information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
                ON tc.constraint_name = kcu.constraint_name
                AND tc.constraint_schema = kcu.constraint_schema
            WHERE tc.constraint_type = 'PRIMARY KEY'
                AND ({" OR ".join(table_filters)})
            ORDER BY tc.table_schema, tc.table_name, kcu.ordinal_position;
        """

        return await self._execute_query(connection, query, tuple(params))

    async def get_indexes_info(self, connection, tables: list) -> list[dict[str, Any]]:
        """Get indexes information for DuckDB.

        information_schema carries no index metadata in DuckDB, so this
        queries duckdb_indexes() and parses each index's CREATE ... SQL
        text for uniqueness and column names.
        """
        if not tables:
            return []

        indexes: list[dict[str, Any]] = []
        for table in tables:
            schema = table["table_schema"]
            table_name = table["table_name"]
            query = """
                SELECT
                    schema_name,
                    table_name,
                    index_name,
                    sql
                FROM duckdb_indexes()
                WHERE schema_name = ? AND table_name = ?;
            """
            rows = await self._execute_query(connection, query, (schema, table_name))

            for row in rows:
                sql_text = (row.get("sql") or "").strip()
                upper_sql = sql_text.upper()
                # UNIQUE appears before the column list in
                # "CREATE UNIQUE INDEX ... (cols)", so only inspect the
                # text before the first parenthesis.
                unique = "UNIQUE" in upper_sql.split("(")[0]

                # Extract the column list between the outermost parens.
                # NOTE(review): this naive comma split would mis-parse
                # expression indexes whose expressions contain commas —
                # confirm that is acceptable here.
                columns: list[str] = []
                if "(" in sql_text and ")" in sql_text:
                    column_section = sql_text[
                        sql_text.find("(") + 1 : sql_text.rfind(")")
                    ]
                    columns = [
                        col.strip().strip('"')
                        for col in column_section.split(",")
                        if col.strip()
                    ]

                indexes.append(
                    {
                        # Fall back to the requested identifiers (or "main")
                        # when duckdb_indexes() returns NULLs.
                        "table_schema": row.get("schema_name") or schema or "main",
                        "table_name": row.get("table_name") or table_name,
                        "index_name": row.get("index_name"),
                        "is_unique": unique,
                        # duckdb_indexes() does not report an index type.
                        "index_type": None,
                        "column_names": columns,
                    }
                )

        return indexes

    async def list_tables_info(self, connection) -> list[dict[str, Any]]:
        """Get list of tables with basic information for DuckDB."""
        query = """
            SELECT
                table_schema,
                table_name,
                table_type
            FROM information_schema.tables
            WHERE table_schema NOT IN ('information_schema', 'pg_catalog', 'duckdb_catalog')
            ORDER BY table_schema, table_name;
        """

        return await self._execute_query(connection, query)