sqlsaber 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/__init__.py +2 -2
- sqlsaber/agents/base.py +1 -1
- sqlsaber/agents/mcp.py +1 -1
- sqlsaber/agents/pydantic_ai_agent.py +207 -135
- sqlsaber/cli/commands.py +11 -28
- sqlsaber/cli/completers.py +2 -0
- sqlsaber/cli/database.py +1 -1
- sqlsaber/cli/display.py +29 -9
- sqlsaber/cli/interactive.py +22 -15
- sqlsaber/cli/streaming.py +15 -17
- sqlsaber/cli/threads.py +10 -6
- sqlsaber/config/settings.py +25 -2
- sqlsaber/database/__init__.py +55 -1
- sqlsaber/database/base.py +124 -0
- sqlsaber/database/csv.py +133 -0
- sqlsaber/database/duckdb.py +313 -0
- sqlsaber/database/mysql.py +345 -0
- sqlsaber/database/postgresql.py +328 -0
- sqlsaber/database/schema.py +66 -963
- sqlsaber/database/sqlite.py +258 -0
- sqlsaber/mcp/mcp.py +1 -1
- sqlsaber/tools/sql_tools.py +1 -1
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/METADATA +43 -9
- sqlsaber-0.26.0.dist-info/RECORD +52 -0
- sqlsaber/database/connection.py +0 -535
- sqlsaber-0.25.0.dist-info/RECORD +0 -47
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.26.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
"""DuckDB database connection and schema introspection."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import duckdb
|
|
7
|
+
|
|
8
|
+
from .base import (
|
|
9
|
+
DEFAULT_QUERY_TIMEOUT,
|
|
10
|
+
BaseDatabaseConnection,
|
|
11
|
+
BaseSchemaIntrospector,
|
|
12
|
+
QueryTimeoutError,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _execute_duckdb_transaction(
    conn: duckdb.DuckDBPyConnection, query: str, args: tuple[Any, ...]
) -> list[dict[str, Any]]:
    """Execute *query* inside a transaction that is always rolled back.

    Rolling back on success and failure alike keeps the call side-effect
    free on the database.  Returns the result set as a list of
    column-name -> value dictionaries (empty when the statement produces
    no result set, e.g. DDL).
    """
    conn.execute("BEGIN TRANSACTION")
    try:
        if args:
            conn.execute(query, args)
        else:
            conn.execute(query)

        description = conn.description
        if description is None:
            # No result set to fetch.
            result: list[dict[str, Any]] = []
        else:
            names = [entry[0] for entry in description]
            result = [dict(zip(names, values)) for values in conn.fetchall()]

        conn.execute("ROLLBACK")
        return result
    except Exception:
        conn.execute("ROLLBACK")
        raise
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class DuckDBConnection(BaseDatabaseConnection):
    """DuckDB database connection using the duckdb Python API.

    Connections are opened per query rather than pooled; the connection
    string may be a ``duckdb://``/``duckdb:///`` URL or a raw file path,
    and an empty path falls back to an in-memory database.
    """

    def __init__(self, connection_string: str):
        super().__init__(connection_string)
        # Strip a recognised URL scheme prefix; anything else is taken
        # verbatim as a filesystem path.  The triple-slash form is tried
        # first so "duckdb:///x" yields "x", not "/x".
        for prefix in ("duckdb:///", "duckdb://"):
            if connection_string.startswith(prefix):
                db_path = connection_string[len(prefix):]
                break
        else:
            db_path = connection_string

        self.database_path = db_path or ":memory:"

    async def get_pool(self):
        """DuckDB creates connections per query; the "pool" is just the path."""
        return self.database_path

    async def close(self):
        """No persistent pool exists, so there is nothing to close."""
        pass

    async def execute_query(
        self, query: str, *args, timeout: float | None = None
    ) -> list[dict[str, Any]]:
        """Execute a query and return results as a list of dicts.

        All queries run in a transaction that is rolled back at the end,
        ensuring no changes are persisted to the database.  Raises
        QueryTimeoutError when the query exceeds the (possibly default)
        timeout.
        """
        effective_timeout = timeout if timeout else DEFAULT_QUERY_TIMEOUT
        bound_args = tuple(args)

        def _query_in_thread() -> list[dict[str, Any]]:
            # duckdb is synchronous, so the work is pushed to a thread.
            connection = duckdb.connect(self.database_path)
            try:
                return _execute_duckdb_transaction(connection, query, bound_args)
            finally:
                connection.close()

        try:
            return await asyncio.wait_for(
                asyncio.to_thread(_query_in_thread), timeout=effective_timeout
            )
        except asyncio.TimeoutError as exc:
            raise QueryTimeoutError(effective_timeout or 0) from exc
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class DuckDBSchemaIntrospector(BaseSchemaIntrospector):
    """DuckDB-specific schema introspection.

    Reads table/column/key metadata from DuckDB's ``information_schema``
    views and index metadata from the ``duckdb_indexes()`` table function.
    """

    async def _execute_query(
        self,
        connection,
        query: str,
        params: tuple[Any, ...] = (),
    ) -> list[dict[str, Any]]:
        """Run a DuckDB query on a worker thread and return rows as dicts.

        Args:
            connection: A DuckDB-backed connection object.  CSV-backed
                connections (detected via a ``csv_path`` attribute) are
                delegated to their own ``execute_query``.
            query: SQL text with ``?`` placeholders.
            params: Positional bind parameters.

        Returns:
            List of column-name -> value dictionaries (empty when the
            statement produces no result set).
        """

        params_tuple = tuple(params)

        def fetch_rows(conn: duckdb.DuckDBPyConnection) -> list[dict[str, Any]]:
            cursor = conn.execute(query, params_tuple)
            if cursor.description is None:
                return []

            columns = [col[0] for col in cursor.description]
            rows = conn.fetchall()
            return [dict(zip(columns, row)) for row in rows]

        # Handle CSV connections differently: they manage their own DuckDB
        # database, so the query must run through their execute_query.
        if hasattr(connection, "execute_query") and hasattr(connection, "csv_path"):
            return await connection.execute_query(query, *params_tuple)

        def run_query() -> list[dict[str, Any]]:
            conn = duckdb.connect(connection.database_path)
            try:
                return fetch_rows(conn)
            finally:
                conn.close()

        return await asyncio.to_thread(run_query)

    async def get_tables_info(
        self, connection, table_pattern: str | None = None
    ) -> list[dict[str, Any]]:
        """Get tables information for DuckDB.

        ``table_pattern`` may be a bare table-name pattern or a
        ``schema.table`` pattern; SQL ``LIKE`` wildcards are honoured.
        System schemas are always excluded.
        """
        where_conditions = [
            "table_schema NOT IN ('information_schema', 'pg_catalog', 'duckdb_catalog')"
        ]
        params: list[Any] = []

        if table_pattern:
            if "." in table_pattern:
                schema_pattern, table_name_pattern = table_pattern.split(".", 1)
                where_conditions.append("(table_schema LIKE ? AND table_name LIKE ?)")
                params.extend([schema_pattern, table_name_pattern])
            else:
                # Match either the bare table name or the qualified name.
                where_conditions.append(
                    "(table_name LIKE ? OR table_schema || '.' || table_name LIKE ?)"
                )
                params.extend([table_pattern, table_pattern])

        query = f"""
            SELECT
                table_schema,
                table_name,
                table_type
            FROM information_schema.tables
            WHERE {" AND ".join(where_conditions)}
            ORDER BY table_schema, table_name;
        """

        return await self._execute_query(connection, query, tuple(params))

    async def get_columns_info(self, connection, tables: list) -> list[dict[str, Any]]:
        """Get columns information for DuckDB.

        Args:
            tables: Table dicts carrying ``table_schema``/``table_name`` keys.
        """
        if not tables:
            return []

        # Build the filter and its bind parameters in a single pass, matching
        # the other introspection methods in this class (the original used two
        # separate loops over the same list).
        table_filters = []
        params: list[Any] = []
        for table in tables:
            table_filters.append("(table_schema = ? AND table_name = ?)")
            params.extend([table["table_schema"], table["table_name"]])

        query = f"""
            SELECT
                table_schema,
                table_name,
                column_name,
                data_type,
                is_nullable,
                column_default,
                character_maximum_length,
                numeric_precision,
                numeric_scale
            FROM information_schema.columns
            WHERE {" OR ".join(table_filters)}
            ORDER BY table_schema, table_name, ordinal_position;
        """

        return await self._execute_query(connection, query, tuple(params))

    async def get_foreign_keys_info(
        self, connection, tables: list
    ) -> list[dict[str, Any]]:
        """Get foreign keys information for DuckDB.

        Joins referential constraints to both sides' key-column usage to
        resolve the referencing and referenced columns.
        """
        if not tables:
            return []

        table_filters = []
        params: list[Any] = []
        for table in tables:
            table_filters.append("(kcu.table_schema = ? AND kcu.table_name = ?)")
            params.extend([table["table_schema"], table["table_name"]])

        query = f"""
            SELECT
                kcu.table_schema,
                kcu.table_name,
                kcu.column_name,
                ccu.table_schema AS foreign_table_schema,
                ccu.table_name AS foreign_table_name,
                ccu.column_name AS foreign_column_name
            FROM information_schema.referential_constraints AS rc
            JOIN information_schema.key_column_usage AS kcu
                ON rc.constraint_schema = kcu.constraint_schema
                AND rc.constraint_name = kcu.constraint_name
            JOIN information_schema.key_column_usage AS ccu
                ON rc.unique_constraint_schema = ccu.constraint_schema
                AND rc.unique_constraint_name = ccu.constraint_name
                AND ccu.ordinal_position = kcu.position_in_unique_constraint
            WHERE {" OR ".join(table_filters)}
            ORDER BY kcu.table_schema, kcu.table_name, kcu.ordinal_position;
        """

        return await self._execute_query(connection, query, tuple(params))

    async def get_primary_keys_info(
        self, connection, tables: list
    ) -> list[dict[str, Any]]:
        """Get primary keys information for DuckDB."""
        if not tables:
            return []

        table_filters = []
        params: list[Any] = []
        for table in tables:
            table_filters.append("(tc.table_schema = ? AND tc.table_name = ?)")
            params.extend([table["table_schema"], table["table_name"]])

        query = f"""
            SELECT
                tc.table_schema,
                tc.table_name,
                kcu.column_name
            FROM information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
                ON tc.constraint_name = kcu.constraint_name
                AND tc.constraint_schema = kcu.constraint_schema
            WHERE tc.constraint_type = 'PRIMARY KEY'
                AND ({" OR ".join(table_filters)})
            ORDER BY tc.table_schema, tc.table_name, kcu.ordinal_position;
        """

        return await self._execute_query(connection, query, tuple(params))

    async def get_indexes_info(self, connection, tables: list) -> list[dict[str, Any]]:
        """Get indexes information for DuckDB.

        DuckDB exposes index DDL via ``duckdb_indexes()``; uniqueness and
        column names are recovered by parsing the stored ``sql`` text,
        which is best-effort (expression indexes may yield odd entries).
        """
        if not tables:
            return []

        indexes: list[dict[str, Any]] = []
        for table in tables:
            schema = table["table_schema"]
            table_name = table["table_name"]
            query = """
                SELECT
                    schema_name,
                    table_name,
                    index_name,
                    sql
                FROM duckdb_indexes()
                WHERE schema_name = ? AND table_name = ?;
            """
            rows = await self._execute_query(connection, query, (schema, table_name))

            for row in rows:
                sql_text = (row.get("sql") or "").strip()
                upper_sql = sql_text.upper()
                # Only count UNIQUE when it appears before the column list.
                unique = "UNIQUE" in upper_sql.split("(")[0]

                columns: list[str] = []
                if "(" in sql_text and ")" in sql_text:
                    column_section = sql_text[
                        sql_text.find("(") + 1 : sql_text.rfind(")")
                    ]
                    columns = [
                        col.strip().strip('"')
                        for col in column_section.split(",")
                        if col.strip()
                    ]

                indexes.append(
                    {
                        "table_schema": row.get("schema_name") or schema or "main",
                        "table_name": row.get("table_name") or table_name,
                        "index_name": row.get("index_name"),
                        "is_unique": unique,
                        # duckdb_indexes() does not report an index type.
                        "index_type": None,
                        "column_names": columns,
                    }
                )

        return indexes

    async def list_tables_info(self, connection) -> list[dict[str, Any]]:
        """Get list of tables with basic information for DuckDB."""
        query = """
            SELECT
                table_schema,
                table_name,
                table_type
            FROM information_schema.tables
            WHERE table_schema NOT IN ('information_schema', 'pg_catalog', 'duckdb_catalog')
            ORDER BY table_schema, table_name;
        """

        return await self._execute_query(connection, query)
|
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
"""MySQL database connection and schema introspection."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import ssl
|
|
5
|
+
from typing import Any
|
|
6
|
+
from urllib.parse import parse_qs, urlparse
|
|
7
|
+
|
|
8
|
+
import aiomysql
|
|
9
|
+
|
|
10
|
+
from .base import (
|
|
11
|
+
DEFAULT_QUERY_TIMEOUT,
|
|
12
|
+
BaseDatabaseConnection,
|
|
13
|
+
BaseSchemaIntrospector,
|
|
14
|
+
QueryTimeoutError,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class MySQLConnection(BaseDatabaseConnection):
    """MySQL database connection using aiomysql.

    A lazily-created connection pool is shared across queries; every
    query runs inside a transaction that is always rolled back so the
    database is never modified.
    """

    def __init__(self, connection_string: str):
        super().__init__(connection_string)
        self._pool: aiomysql.Pool | None = None
        self._parse_connection_string()

    def _parse_connection_string(self):
        """Split the mysql:// URL into host/port/db/credentials and SSL options."""
        url = urlparse(self.connection_string)
        self.host = url.hostname or "localhost"
        self.port = url.port or 3306
        self.database = url.path.lstrip("/") if url.path else ""
        self.user = url.username or ""
        self.password = url.password or ""

        # SSL configuration is carried in the query string (ssl_mode,
        # ssl_ca, ssl_cert, ssl_key).  Unknown modes leave SSL untouched.
        self.ssl_params = {}
        if not url.query:
            return

        options = parse_qs(url.query)
        raw_mode = options.get("ssl_mode", [None])[0]
        if not raw_mode:
            return

        mode = raw_mode.upper()
        if mode == "DISABLED":
            self.ssl_params["ssl"] = None
            return
        if mode not in ("PREFERRED", "REQUIRED", "VERIFY_CA", "VERIFY_IDENTITY"):
            return

        context = ssl.create_default_context()
        if mode == "REQUIRED":
            # Encrypt, but do not verify the server at all.
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
        elif mode == "VERIFY_CA":
            # Verify the certificate chain but not the hostname.
            context.check_hostname = False
            context.verify_mode = ssl.CERT_REQUIRED
        elif mode == "VERIFY_IDENTITY":
            context.check_hostname = True
            context.verify_mode = ssl.CERT_REQUIRED
        # PREFERRED keeps the default (fully verifying) context unchanged.

        ca_file = options.get("ssl_ca", [None])[0]
        cert_file = options.get("ssl_cert", [None])[0]
        key_file = options.get("ssl_key", [None])[0]

        if ca_file:
            context.load_verify_locations(ca_file)
        if cert_file and key_file:
            context.load_cert_chain(cert_file, key_file)

        self.ssl_params["ssl"] = context

    async def get_pool(self) -> aiomysql.Pool:
        """Lazily create and return the shared aiomysql connection pool."""
        if self._pool is None:
            settings = {
                "host": self.host,
                "port": self.port,
                "user": self.user,
                "password": self.password,
                "db": self.database,
                "minsize": 1,
                "maxsize": 10,
                # Explicit transactions; commits never happen (see execute_query).
                "autocommit": False,
            }
            # Merge in any SSL configuration from the connection string.
            settings.update(self.ssl_params)
            self._pool = await aiomysql.create_pool(**settings)
        return self._pool

    async def close(self):
        """Close the connection pool and wait for connections to drain."""
        if self._pool:
            self._pool.close()
            await self._pool.wait_closed()
            self._pool = None

    async def execute_query(
        self, query: str, *args, timeout: float | None = None
    ) -> list[dict[str, Any]]:
        """Execute a query and return results as a list of dicts.

        All queries run in a transaction that is rolled back at the end,
        ensuring no changes are persisted to the database.  Raises
        QueryTimeoutError when the client-side timeout elapses.
        """
        deadline = timeout or DEFAULT_QUERY_TIMEOUT
        pool = await self.get_pool()

        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await conn.begin()
                try:
                    if deadline:
                        # Server-side guard as well: clamp to 10ms..5min.
                        millis = max(10, min(int(deadline * 1000), 300000))
                        await cursor.execute(
                            f"SET SESSION MAX_EXECUTION_TIME = {millis}"
                        )

                        # Client-side guard on both execute and fetch.
                        await asyncio.wait_for(
                            cursor.execute(query, args if args else None),
                            timeout=deadline,
                        )
                        rows = await asyncio.wait_for(
                            cursor.fetchall(), timeout=deadline
                        )
                    else:
                        await cursor.execute(query, args if args else None)
                        rows = await cursor.fetchall()

                    return [dict(row) for row in rows]
                except asyncio.TimeoutError as exc:
                    raise QueryTimeoutError(deadline or 0) from exc
                finally:
                    # Always rollback so nothing is ever committed.
                    await conn.rollback()
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class MySQLSchemaIntrospector(BaseSchemaIntrospector):
    """MySQL-specific schema introspection.

    All methods fetch rows with ``aiomysql.DictCursor`` so results are
    dictionaries keyed by column name.  The plain default cursor returns
    tuples, which broke consumers that index rows by name (including
    ``list_tables_info`` below and ``_build_table_filter_clause``).
    """

    def _build_table_filter_clause(self, tables: list) -> tuple[str, list]:
        """Build row constructor with bind parameters for table filtering.

        Args:
            tables: List of table dictionaries with table_schema and table_name keys

        Returns:
            Tuple of (placeholders, params) for use in SQL queries
        """
        if not tables:
            return "", []

        table_pairs = [(table["table_schema"], table["table_name"]) for table in tables]
        placeholders = ", ".join(["(%s, %s)"] * len(table_pairs))
        params = [value for pair in table_pairs for value in pair]
        return placeholders, params

    async def get_tables_info(
        self, connection, table_pattern: str | None = None
    ) -> list[dict[str, Any]]:
        """Get tables information for MySQL.

        ``table_pattern`` may be a bare table name or ``schema.table``;
        SQL ``LIKE`` wildcards are honoured.  System schemas are excluded.
        (Return annotation fixed: this returns a list of row dicts.)
        """
        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            # DictCursor so each row is a dict keyed by column name.
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                # Build WHERE clause for filtering
                where_conditions = [
                    "table_schema NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')"
                ]
                params = []

                if table_pattern:
                    # Support patterns like 'schema.table' or just 'table'
                    if "." in table_pattern:
                        schema_pattern, table_name_pattern = table_pattern.split(".", 1)
                        where_conditions.append(
                            "(table_schema LIKE %s AND table_name LIKE %s)"
                        )
                        params.extend([schema_pattern, table_name_pattern])
                    else:
                        where_conditions.append(
                            "(table_name LIKE %s OR CONCAT(table_schema, '.', table_name) LIKE %s)"
                        )
                        params.extend([table_pattern, table_pattern])

                # Get tables
                tables_query = f"""
                    SELECT
                        table_schema,
                        table_name,
                        table_type
                    FROM information_schema.tables
                    WHERE {" AND ".join(where_conditions)}
                    ORDER BY table_schema, table_name;
                """
                await cursor.execute(tables_query, params)
                return await cursor.fetchall()

    async def get_columns_info(self, connection, tables: list) -> list:
        """Get columns information for MySQL."""
        if not tables:
            return []

        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                placeholders, params = self._build_table_filter_clause(tables)

                columns_query = f"""
                    SELECT
                        c.table_schema,
                        c.table_name,
                        c.column_name,
                        c.data_type,
                        c.is_nullable,
                        c.column_default,
                        c.character_maximum_length,
                        c.numeric_precision,
                        c.numeric_scale
                    FROM information_schema.columns c
                    WHERE (c.table_schema, c.table_name) IN ({placeholders})
                    ORDER BY c.table_schema, c.table_name, c.ordinal_position;
                """
                await cursor.execute(columns_query, params)
                return await cursor.fetchall()

    async def get_foreign_keys_info(self, connection, tables: list) -> list:
        """Get foreign keys information for MySQL."""
        if not tables:
            return []

        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                placeholders, params = self._build_table_filter_clause(tables)

                fk_query = f"""
                    SELECT
                        tc.table_schema,
                        tc.table_name,
                        kcu.column_name,
                        rc.unique_constraint_schema AS foreign_table_schema,
                        rc.referenced_table_name AS foreign_table_name,
                        kcu.referenced_column_name AS foreign_column_name
                    FROM information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                        ON tc.constraint_name = kcu.constraint_name
                        AND tc.table_schema = kcu.table_schema
                    JOIN information_schema.referential_constraints AS rc
                        ON tc.constraint_name = rc.constraint_name
                        AND tc.table_schema = rc.constraint_schema
                    WHERE tc.constraint_type = 'FOREIGN KEY'
                        AND (tc.table_schema, tc.table_name) IN ({placeholders});
                """
                await cursor.execute(fk_query, params)
                return await cursor.fetchall()

    async def get_primary_keys_info(self, connection, tables: list) -> list:
        """Get primary keys information for MySQL."""
        if not tables:
            return []

        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                placeholders, params = self._build_table_filter_clause(tables)

                pk_query = f"""
                    SELECT
                        tc.table_schema,
                        tc.table_name,
                        kcu.column_name
                    FROM information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                        ON tc.constraint_name = kcu.constraint_name
                        AND tc.table_schema = kcu.table_schema
                    WHERE tc.constraint_type = 'PRIMARY KEY'
                        AND (tc.table_schema, tc.table_name) IN ({placeholders})
                    ORDER BY tc.table_schema, tc.table_name, kcu.ordinal_position;
                """
                await cursor.execute(pk_query, params)
                return await cursor.fetchall()

    async def get_indexes_info(self, connection, tables: list) -> list:
        """Get indexes information for MySQL.

        NOTE(review): column_names is a GROUP_CONCAT string here, while the
        DuckDB introspector returns a list — confirm consumers accept both.
        """
        if not tables:
            return []

        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                placeholders, params = self._build_table_filter_clause(tables)

                idx_query = f"""
                    SELECT
                        TABLE_SCHEMA AS table_schema,
                        TABLE_NAME AS table_name,
                        INDEX_NAME AS index_name,
                        (NON_UNIQUE = 0) AS is_unique,
                        INDEX_TYPE AS index_type,
                        GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX) AS column_names
                    FROM INFORMATION_SCHEMA.STATISTICS
                    WHERE (TABLE_SCHEMA, TABLE_NAME) IN ({placeholders})
                    GROUP BY table_schema, table_name, index_name, is_unique, index_type
                    ORDER BY table_schema, table_name, index_name;
                """
                await cursor.execute(idx_query, params)
                return await cursor.fetchall()

    async def list_tables_info(self, connection) -> list[dict[str, Any]]:
        """Get list of tables with basic information for MySQL."""
        pool = await connection.get_pool()
        async with pool.acquire() as conn:
            # DictCursor is required: the comprehension below indexes rows
            # by column name, which fails with the default tuple cursor.
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                # Get tables without row counts for better performance
                tables_query = """
                    SELECT
                        t.table_schema,
                        t.table_name,
                        t.table_type
                    FROM information_schema.tables t
                    WHERE t.table_schema NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
                    ORDER BY t.table_schema, t.table_name;
                """
                await cursor.execute(tables_query)
                rows = await cursor.fetchall()

                # Convert rows to dictionaries
                return [
                    {
                        "table_schema": row["table_schema"],
                        "table_name": row["table_name"],
                        "table_type": row["table_type"],
                    }
                    for row in rows
                ]
|