sqlsaber 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/__init__.py +2 -2
- sqlsaber/agents/base.py +1 -1
- sqlsaber/agents/mcp.py +1 -1
- sqlsaber/agents/pydantic_ai_agent.py +207 -135
- sqlsaber/cli/commands.py +11 -28
- sqlsaber/cli/completers.py +2 -0
- sqlsaber/cli/database.py +1 -1
- sqlsaber/cli/display.py +29 -9
- sqlsaber/cli/interactive.py +22 -15
- sqlsaber/cli/streaming.py +15 -17
- sqlsaber/cli/threads.py +10 -6
- sqlsaber/config/settings.py +25 -2
- sqlsaber/database/__init__.py +55 -1
- sqlsaber/database/base.py +124 -0
- sqlsaber/database/csv.py +133 -0
- sqlsaber/database/duckdb.py +313 -0
- sqlsaber/database/mysql.py +345 -0
- sqlsaber/database/postgresql.py +328 -0
- sqlsaber/database/schema.py +66 -963
- sqlsaber/database/sqlite.py +258 -0
- sqlsaber/mcp/mcp.py +1 -1
- sqlsaber/tools/sql_tools.py +1 -1
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/METADATA +43 -9
- sqlsaber-0.26.0.dist-info/RECORD +52 -0
- sqlsaber/database/connection.py +0 -535
- sqlsaber-0.25.0.dist-info/RECORD +0 -47
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,328 @@
|
|
|
1
|
+
"""PostgreSQL database connection and schema introspection."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import ssl
|
|
5
|
+
from typing import Any
|
|
6
|
+
from urllib.parse import parse_qs, urlparse
|
|
7
|
+
|
|
8
|
+
import asyncpg
|
|
9
|
+
|
|
10
|
+
from .base import (
|
|
11
|
+
DEFAULT_QUERY_TIMEOUT,
|
|
12
|
+
BaseDatabaseConnection,
|
|
13
|
+
BaseSchemaIntrospector,
|
|
14
|
+
QueryTimeoutError,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PostgreSQLConnection(BaseDatabaseConnection):
    """PostgreSQL database connection using asyncpg."""

    def __init__(self, connection_string: str):
        super().__init__(connection_string)
        self._pool: asyncpg.Pool | None = None
        # Parsed once up front; None when SSL is disabled or unspecified.
        self._ssl_context = self._create_ssl_context()

    def _create_ssl_context(self) -> ssl.SSLContext | None:
        """Create an SSL context from connection string query parameters.

        Reads libpq-style parameters (``sslmode``, ``sslrootcert``,
        ``sslcert``, ``sslkey``) from the connection string's query string.

        Returns:
            A configured ssl.SSLContext, or None when the connection string
            has no query string, ``sslmode`` is absent/"disable", or the
            mode is not one of the recognized values.
        """
        parsed = urlparse(self.connection_string)
        if not parsed.query:
            return None

        params = parse_qs(parsed.query)
        ssl_mode = params.get("sslmode", [None])[0]

        if not ssl_mode or ssl_mode == "disable":
            return None

        if ssl_mode not in ("require", "verify-ca", "verify-full"):
            # Unrecognized modes (e.g. "allow", "prefer") fall through to
            # no explicit SSL context, matching the original behavior.
            return None

        ssl_context = ssl.create_default_context()

        # Map the libpq sslmode onto Python's verification flags.
        if ssl_mode == "require":
            # Encrypt the channel but do not verify the server certificate.
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
        elif ssl_mode == "verify-ca":
            # Verify the certificate chain but not the hostname.
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_REQUIRED
        else:  # verify-full
            ssl_context.check_hostname = True
            ssl_context.verify_mode = ssl.CERT_REQUIRED

        # Optional CA bundle and client certificate material.
        ssl_ca = params.get("sslrootcert", [None])[0]
        ssl_cert = params.get("sslcert", [None])[0]
        ssl_key = params.get("sslkey", [None])[0]

        if ssl_ca:
            ssl_context.load_verify_locations(ssl_ca)

        if ssl_cert and ssl_key:
            ssl_context.load_cert_chain(ssl_cert, ssl_key)

        return ssl_context

    async def get_pool(self) -> asyncpg.Pool:
        """Get or lazily create the connection pool."""
        if self._pool is None:
            # asyncpg's create_pool accepts ssl=None as "no explicit SSL
            # context", so a single call covers both configurations instead
            # of duplicating the call in two branches.
            self._pool = await asyncpg.create_pool(
                self.connection_string,
                min_size=1,
                max_size=10,
                ssl=self._ssl_context,
            )
        return self._pool

    async def close(self):
        """Close the connection pool, if one was created."""
        if self._pool:
            await self._pool.close()
            self._pool = None

    async def execute_query(
        self, query: str, *args, timeout: float | None = None
    ) -> list[dict[str, Any]]:
        """Execute a query and return results as a list of dicts.

        All queries run in a transaction that is rolled back at the end,
        ensuring no changes are persisted to the database.

        Args:
            query: SQL text with ``$1``-style bind placeholders.
            *args: Bind parameters for the placeholders.
            timeout: Per-query timeout in seconds; falls back to
                DEFAULT_QUERY_TIMEOUT when not given.

        Raises:
            QueryTimeoutError: If the query exceeds the effective timeout.
        """
        effective_timeout = timeout or DEFAULT_QUERY_TIMEOUT
        pool = await self.get_pool()

        async with pool.acquire() as conn:
            # Start a transaction that we'll always roll back.
            transaction = conn.transaction()
            await transaction.start()

            try:
                if effective_timeout:
                    # Server-side guard: clamp to a sane range (10ms to 5
                    # minutes). timeout_ms is a validated int, so the
                    # f-string interpolation cannot inject SQL.
                    timeout_ms = max(10, min(int(effective_timeout * 1000), 300000))
                    await conn.execute(f"SET LOCAL statement_timeout = {timeout_ms}")

                    # Client-side guard as a backstop to the server limit.
                    rows = await asyncio.wait_for(
                        conn.fetch(query, *args), timeout=effective_timeout
                    )
                else:
                    rows = await conn.fetch(query, *args)

                return [dict(row) for row in rows]
            except asyncio.TimeoutError as exc:
                raise QueryTimeoutError(effective_timeout or 0) from exc
            finally:
                # Always rollback to ensure no changes are committed.
                await transaction.rollback()
128
|
+
|
|
129
|
+
|
|
130
|
+
class PostgreSQLSchemaIntrospector(BaseSchemaIntrospector):
    """PostgreSQL-specific schema introspection."""

    def _build_table_filter_clause(self, tables: list) -> tuple[str, list]:
        """Build VALUES clause with bind parameters for table filtering.

        Args:
            tables: List of table dictionaries with table_schema and
                table_name keys.

        Returns:
            Tuple of (values_clause, params) for use in SQL queries; both
            are empty when *tables* is empty.
        """
        if not tables:
            return "", []

        table_pairs = [(table["table_schema"], table["table_name"]) for table in tables]
        # asyncpg placeholders are 1-based; each (schema, name) pair
        # consumes two consecutive $N slots.
        values_clause = ", ".join(
            [f"(${2 * i + 1}, ${2 * i + 2})" for i in range(len(table_pairs))]
        )
        params = [value for pair in table_pairs for value in pair]
        return values_clause, params

    async def get_tables_info(
        self, connection, table_pattern: str | None = None
    ) -> list:
        """Get tables information for PostgreSQL.

        Args:
            connection: Connection object exposing ``get_pool()``.
            table_pattern: Optional LIKE pattern, either 'table' or
                'schema.table'.

        Returns:
            List of records with table_schema, table_name, table_type.
            (Annotation fixed: ``conn.fetch`` returns a list, not a dict.)
        """
        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            # System schemas are always excluded.
            where_conditions = [
                "table_schema NOT IN ('pg_catalog', 'information_schema')"
            ]
            params = []

            if table_pattern:
                # Support patterns like 'schema.table' or just 'table'
                if "." in table_pattern:
                    schema_pattern, table_name_pattern = table_pattern.split(".", 1)
                    where_conditions.append(
                        "(table_schema LIKE $1 AND table_name LIKE $2)"
                    )
                    params.extend([schema_pattern, table_name_pattern])
                else:
                    # Match against the bare name or the qualified name.
                    where_conditions.append(
                        "(table_name LIKE $1 OR table_schema || '.' || table_name LIKE $1)"
                    )
                    params.append(table_pattern)

            tables_query = f"""
                SELECT
                    table_schema,
                    table_name,
                    table_type
                FROM information_schema.tables
                WHERE {" AND ".join(where_conditions)}
                ORDER BY table_schema, table_name;
            """
            return await conn.fetch(tables_query, *params)

    async def get_columns_info(self, connection, tables: list) -> list:
        """Get columns information for PostgreSQL.

        Returns one record per column, ordered by schema, table, and
        ordinal position; empty list when *tables* is empty.
        """
        if not tables:
            return []

        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            values_clause, params = self._build_table_filter_clause(tables)

            columns_query = f"""
                SELECT
                    c.table_schema,
                    c.table_name,
                    c.column_name,
                    c.data_type,
                    c.is_nullable,
                    c.column_default,
                    c.character_maximum_length,
                    c.numeric_precision,
                    c.numeric_scale
                FROM information_schema.columns c
                WHERE (c.table_schema, c.table_name) IN (VALUES {values_clause})
                ORDER BY c.table_schema, c.table_name, c.ordinal_position;
            """
            return await conn.fetch(columns_query, *params)

    async def get_foreign_keys_info(self, connection, tables: list) -> list:
        """Get foreign keys information for PostgreSQL.

        Joins table_constraints with key_column_usage (local columns) and
        constraint_column_usage (referenced columns); empty list when
        *tables* is empty.
        """
        if not tables:
            return []

        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            values_clause, params = self._build_table_filter_clause(tables)

            fk_query = f"""
                WITH t(schema, name) AS (VALUES {values_clause})
                SELECT
                    tc.table_schema,
                    tc.table_name,
                    kcu.column_name,
                    ccu.table_schema AS foreign_table_schema,
                    ccu.table_name AS foreign_table_name,
                    ccu.column_name AS foreign_column_name
                FROM information_schema.table_constraints AS tc
                JOIN information_schema.key_column_usage AS kcu
                    ON tc.constraint_name = kcu.constraint_name
                    AND tc.table_schema = kcu.table_schema
                JOIN information_schema.constraint_column_usage AS ccu
                    ON ccu.constraint_name = tc.constraint_name
                    AND ccu.table_schema = tc.table_schema
                JOIN t ON t.schema = tc.table_schema AND t.name = tc.table_name
                WHERE tc.constraint_type = 'FOREIGN KEY';
            """
            return await conn.fetch(fk_query, *params)

    async def get_primary_keys_info(self, connection, tables: list) -> list:
        """Get primary keys information for PostgreSQL.

        Returns one record per key column, in key ordinal order; empty
        list when *tables* is empty.
        """
        if not tables:
            return []

        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            values_clause, params = self._build_table_filter_clause(tables)

            pk_query = f"""
                WITH t(schema, name) AS (VALUES {values_clause})
                SELECT
                    tc.table_schema,
                    tc.table_name,
                    kcu.column_name
                FROM information_schema.table_constraints AS tc
                JOIN information_schema.key_column_usage AS kcu
                    ON tc.constraint_name = kcu.constraint_name
                    AND tc.table_schema = kcu.table_schema
                JOIN t ON t.schema = tc.table_schema AND t.name = tc.table_name
                WHERE tc.constraint_type = 'PRIMARY KEY'
                ORDER BY tc.table_schema, tc.table_name, kcu.ordinal_position;
            """
            return await conn.fetch(pk_query, *params)

    async def get_indexes_info(self, connection, tables: list) -> list:
        """Get indexes information for PostgreSQL.

        Queries pg_catalog directly (not information_schema) so the index
        access method and uniqueness flags are available; empty list when
        *tables* is empty.
        """
        if not tables:
            return []

        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            values_clause, params = self._build_table_filter_clause(tables)

            # NOTE(review): in LIKE, '_' is a single-char wildcard, so
            # '%_pkey' also excludes names like 'xpkey'-suffixed indexes;
            # '%\\_pkey' would match a literal underscore — confirm intent.
            idx_query = f"""
                WITH t_filter(schema, name) AS (VALUES {values_clause})
                SELECT
                    ns.nspname AS table_schema,
                    tcls.relname AS table_name,
                    icls.relname AS index_name,
                    ix.indisunique AS is_unique,
                    am.amname AS index_type,
                    string_agg(a.attname, ',' ORDER BY att.ordinality) AS column_names
                FROM pg_class tcls
                JOIN pg_namespace ns ON tcls.relnamespace = ns.oid
                JOIN pg_index ix ON tcls.oid = ix.indrelid
                JOIN pg_class icls ON icls.oid = ix.indexrelid
                JOIN pg_am am ON icls.relam = am.oid
                JOIN pg_attribute a ON a.attrelid = tcls.oid
                JOIN unnest(ix.indkey) WITH ORDINALITY AS att(attnum, ordinality) ON a.attnum = att.attnum
                JOIN t_filter ON t_filter.schema = ns.nspname AND t_filter.name = tcls.relname
                WHERE tcls.relkind = 'r'
                AND icls.relname NOT LIKE '%_pkey'
                GROUP BY ns.nspname, tcls.relname, icls.relname, ix.indisunique, am.amname
                ORDER BY ns.nspname, tcls.relname, icls.relname;
            """
            return await conn.fetch(idx_query, *params)

    async def list_tables_info(self, connection) -> list[dict[str, Any]]:
        """Get list of tables with basic information for PostgreSQL."""
        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            # Get table names and basic info without row counts for better performance
            tables_query = """
                SELECT
                    table_schema,
                    table_name,
                    table_type
                FROM information_schema.tables
                WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
                ORDER BY table_schema, table_name;
            """
            tables = await conn.fetch(tables_query)

            # Convert asyncpg records to plain dicts for the callers.
            return [
                {
                    "table_schema": table["table_schema"],
                    "table_name": table["table_name"],
                    "table_type": table["table_type"],
                }
                for table in tables
            ]
|