sql-query-mcp 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_query_mcp/__init__.py +5 -0
- sql_query_mcp/__main__.py +5 -0
- sql_query_mcp/adapters/__init__.py +15 -0
- sql_query_mcp/adapters/mysql.py +171 -0
- sql_query_mcp/adapters/postgres.py +180 -0
- sql_query_mcp/app.py +105 -0
- sql_query_mcp/audit.py +42 -0
- sql_query_mcp/config.py +255 -0
- sql_query_mcp/errors.py +34 -0
- sql_query_mcp/executor.py +243 -0
- sql_query_mcp/introspection.py +225 -0
- sql_query_mcp/namespace.py +48 -0
- sql_query_mcp/registry.py +67 -0
- sql_query_mcp/release_metadata.py +93 -0
- sql_query_mcp/validator.py +128 -0
- sql_query_mcp-0.1.1.dist-info/METADATA +235 -0
- sql_query_mcp-0.1.1.dist-info/RECORD +21 -0
- sql_query_mcp-0.1.1.dist-info/WHEEL +5 -0
- sql_query_mcp-0.1.1.dist-info/entry_points.txt +2 -0
- sql_query_mcp-0.1.1.dist-info/licenses/LICENSE +21 -0
- sql_query_mcp-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Engine adapters for sql-query-mcp."""
|
|
2
|
+
|
|
3
|
+
__all__ = ["MySQLAdapter", "PostgresAdapter"]
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def __getattr__(name: str):
    """Resolve adapter classes lazily on first attribute access (PEP 562).

    Importing this package therefore does not import ``.mysql`` or
    ``.postgres``; each adapter module is loaded only when its class is
    requested.

    Raises:
        AttributeError: for any name other than the two exported adapters,
            with the standard "module ... has no attribute ..." message.
    """
    if name == "MySQLAdapter":
        from .mysql import MySQLAdapter

        return MySQLAdapter
    if name == "PostgresAdapter":
        from .postgres import PostgresAdapter

        return PostgresAdapter
    # Match the interpreter's own wording so error messages stay familiar.
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""MySQL adapter."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
from typing import Iterator, List
|
|
8
|
+
from urllib.parse import parse_qs, unquote, urlparse
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import pymysql
|
|
12
|
+
from pymysql.cursors import DictCursor
|
|
13
|
+
except ImportError: # pragma: no cover - runtime dependency
|
|
14
|
+
pymysql = None
|
|
15
|
+
DictCursor = None
|
|
16
|
+
|
|
17
|
+
from ..errors import ConfigurationError, SecurityError
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class MySQLAdapter:
    """Engine adapter for MySQL, backed by PyMySQL.

    Each ``connection()`` call opens a fresh autocommit connection (no pooling)
    whose cursors return rows as dicts via ``DictCursor``.
    """

    engine = "mysql"

    @contextmanager
    def connection(self, connection_id: str, dsn: str) -> Iterator[object]:
        """Yield a new PyMySQL connection for *dsn* and close it on exit.

        ``connection_id`` is accepted for interface parity with the other
        adapters and is not used here.

        Raises:
            ConfigurationError: if PyMySQL is not installed or the DSN is invalid.
        """
        if pymysql is None or DictCursor is None:
            raise ConfigurationError("缺少 PyMySQL 依赖,请先安装项目依赖。")

        conn = pymysql.connect(
            autocommit=True,
            cursorclass=DictCursor,
            **self._parse_dsn(dsn),
        )
        try:
            yield conn
        finally:
            conn.close()

    def close(self) -> None:
        """No-op: connections are closed after every use, so nothing is held."""
        return None

    def set_statement_timeout(self, conn: object, timeout_ms: int) -> None:
        """Set the session's ``max_execution_time`` to *timeout_ms* milliseconds."""
        with conn.cursor() as cur:
            cur.execute("SET SESSION max_execution_time = %s", (int(timeout_ms),))

    def list_databases(self, conn: object) -> List[str]:
        """Return user database names, excluding MySQL's built-in schemas."""
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT schema_name AS database_name
                FROM information_schema.schemata
                WHERE schema_name NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')
                ORDER BY schema_name
                """
            )
            return [row["database_name"] for row in cur.fetchall()]

    def list_tables(self, conn: object, database: str):
        """Return tables/views in *database* as dicts.

        Each dict has the keys ``database_name``, ``table_name`` and ``table_type``.
        """
        # Explicit aliases keep the result keys lowercase: MySQL 8.0 reports
        # unaliased information_schema columns in uppercase (TABLE_NAME, ...).
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT table_schema AS database_name,
                       table_name AS table_name,
                       table_type AS table_type
                FROM information_schema.tables
                WHERE table_schema = %s
                ORDER BY table_name
                """,
                (database,),
            )
            return cur.fetchall()

    def describe_table(self, conn: object, database: str, table_name: str):
        """Describe the columns and indexes of ``database.table_name``.

        Returns:
            ``{"columns": [...], "indexes": [...]}``, or ``None`` when the
            column query returns no rows (e.g. the table does not exist).
        """
        # As in list_tables, every selected column is aliased so DictCursor
        # keys are lowercase regardless of the MySQL server version.
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT column_name AS column_name,
                       column_type AS column_type,
                       is_nullable AS is_nullable,
                       column_default AS column_default,
                       extra AS extra,
                       column_key AS column_key,
                       ordinal_position AS ordinal_position
                FROM information_schema.columns
                WHERE table_schema = %s AND table_name = %s
                ORDER BY ordinal_position
                """,
                (database, table_name),
            )
            columns = cur.fetchall()
            cur.execute(
                """
                SELECT index_name AS index_name,
                       non_unique AS non_unique,
                       seq_in_index AS seq_in_index,
                       column_name AS column_name
                FROM information_schema.statistics
                WHERE table_schema = %s AND table_name = %s
                ORDER BY index_name, seq_in_index
                """,
                (database, table_name),
            )
            index_rows = cur.fetchall()

        if not columns:
            return None

        return {
            "columns": [
                {
                    "column_name": row["column_name"],
                    # MySQL's column_type carries the full type, e.g. "varchar(255)".
                    "data_type": row["column_type"],
                    "udt_name": None,
                    "nullable": row["is_nullable"] == "YES",
                    "default": row["column_default"],
                    "primary_key": row["column_key"] == "PRI",
                    "extra": row["extra"],
                }
                for row in columns
            ],
            "indexes": self._normalize_indexes(index_rows),
        }

    def build_sample_query(self, database: str, table_name: str, sentinel_limit: int) -> str:
        """Build ``SELECT * ... LIMIT n`` with backtick-quoted identifiers."""
        return (
            f"SELECT * FROM {self._quote_identifier(database)}."
            f"{self._quote_identifier(table_name)} LIMIT {int(sentinel_limit)}"
        )

    def build_explain_query(self, sql_text: str, analyze: bool = False) -> str:
        """Wrap *sql_text* in ``EXPLAIN FORMAT=JSON``.

        Raises:
            SecurityError: when ``analyze`` is requested (not supported for MySQL).
        """
        if analyze:
            raise SecurityError("MySQL 首版不支持 analyze=True。")
        return f"EXPLAIN FORMAT=JSON {sql_text}"

    def extract_plan(self, rows):
        """Return the JSON plan from EXPLAIN output rows.

        The first row's ``EXPLAIN`` value is decoded if it is a JSON string. If
        it cannot be decoded, the raw string is returned unchanged.
        """
        if not rows:
            return []
        plan = rows[0].get("EXPLAIN", [])
        if isinstance(plan, str):
            try:
                return json.loads(plan)
            except json.JSONDecodeError:
                return plan
        return plan

    def column_names(self, description) -> List[str]:
        """Return column names from a DB-API ``cursor.description``."""
        return [column[0] for column in (description or [])]

    def _parse_dsn(self, dsn: str) -> dict:
        """Translate a ``mysql://`` URL into ``pymysql.connect`` keyword arguments.

        Percent-encoded user, password and database components are decoded.
        The last occurrence of a repeated query parameter wins. Keys whose
        value would be ``None`` are omitted.

        Raises:
            ConfigurationError: if the URL scheme is not mysql / mysql+pymysql.
        """
        parsed = urlparse(dsn)
        if parsed.scheme not in {"mysql", "mysql+pymysql"}:
            raise ConfigurationError(f"MySQL DSN 必须使用 mysql:// 或 mysql+pymysql://,当前为 {parsed.scheme}")

        query_params = {key: values[-1] for key, values in parse_qs(parsed.query).items()}
        database = unquote(parsed.path.lstrip("/"))
        connect_args = {
            "host": parsed.hostname or "localhost",
            "user": unquote(parsed.username) if parsed.username else None,
            "password": unquote(parsed.password) if parsed.password else None,
            "port": parsed.port or 3306,
            "database": database or None,
            "charset": query_params.get("charset", "utf8mb4"),
        }
        return {key: value for key, value in connect_args.items() if value is not None}

    def _quote_identifier(self, value: str) -> str:
        """Backtick-quote an identifier, doubling any embedded backticks."""
        return "`" + value.replace("`", "``") + "`"

    def _normalize_indexes(self, rows: List[dict]) -> List[dict]:
        """Group per-column statistics rows into one dict per index, sorted by name.

        Columns keep the order in which rows arrive (the query orders them by
        ``seq_in_index``). Uniqueness is taken from the first row of each index.
        """
        grouped = {}
        for row in rows:
            index_name = row["index_name"]
            item = grouped.setdefault(
                index_name,
                {
                    "index_name": index_name,
                    "columns": [],
                    "unique": row["non_unique"] == 0,
                    "primary_key": index_name == "PRIMARY",
                    "definition": None,
                },
            )
            item["columns"].append(row["column_name"])
        return [grouped[name] for name in sorted(grouped)]
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
"""PostgreSQL adapter."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from typing import Iterator, List
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
from psycopg import sql
|
|
10
|
+
from psycopg.rows import dict_row
|
|
11
|
+
from psycopg_pool import ConnectionPool
|
|
12
|
+
except ImportError: # pragma: no cover - runtime dependency
|
|
13
|
+
sql = None
|
|
14
|
+
dict_row = None
|
|
15
|
+
ConnectionPool = None
|
|
16
|
+
|
|
17
|
+
from ..errors import ConfigurationError
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PostgresAdapter:
    """Engine adapter for PostgreSQL, backed by psycopg 3 and psycopg_pool.

    Keeps one lazily created ``ConnectionPool`` per ``connection_id``. Pooled
    connections are configured with autocommit and the ``dict_row`` row factory.
    """

    engine = "postgres"

    def __init__(self) -> None:
        # connection_id -> ConnectionPool, filled on first use by _get_pool().
        self._pools = {}

    @contextmanager
    def connection(self, connection_id: str, dsn: str) -> Iterator[object]:
        """Borrow a pooled connection for *connection_id*, returning it on exit."""
        pool = self._get_pool(connection_id, dsn)
        with pool.connection() as conn:
            yield conn

    def close(self) -> None:
        """Close every pool and forget it.

        Because the pools are removed from the cache, a later ``connection()``
        call builds a fresh pool instead of reusing a closed one.
        """
        pools = list(self._pools.values())
        self._pools.clear()
        for pool in pools:
            pool.close()

    def set_statement_timeout(self, conn: object, timeout_ms: int) -> None:
        """Set ``statement_timeout`` (milliseconds) via ``set_config``.

        The third argument is ``is_local = false``.
        """
        with conn.cursor() as cur:
            cur.execute("SELECT set_config('statement_timeout', %s, false)", (str(timeout_ms),))

    def list_schemas(self, conn: object) -> List[str]:
        """Return user schema names, excluding information_schema and pg_* schemas."""
        # ``left(..., 3) <> 'pg_'`` rather than ``NOT LIKE 'pg_%'``: in LIKE the
        # underscore is a single-character wildcard, which would also hide
        # user schemas such as "pgbench" or "pgx".
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT schema_name
                FROM information_schema.schemata
                WHERE schema_name <> 'information_schema'
                  AND left(schema_name, 3) <> 'pg_'
                ORDER BY schema_name
                """
            )
            return [row["schema_name"] for row in cur.fetchall()]

    def list_tables(self, conn: object, schema: str):
        """Return tables/views in *schema* as dicts.

        Each dict has the keys ``schema``, ``table_name`` and ``table_type``.
        """
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT table_schema AS schema, table_name, table_type
                FROM information_schema.tables
                WHERE table_schema = %s
                ORDER BY table_name
                """,
                (schema,),
            )
            return cur.fetchall()

    def describe_table(self, conn: object, schema: str, table_name: str):
        """Describe the columns, primary key and indexes of ``schema.table_name``.

        Returns:
            ``{"columns": [...], "indexes": [...]}``, or ``None`` when the
            column query returns no rows (e.g. the table does not exist).
        """
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT column_name, data_type, udt_name, is_nullable, column_default, ordinal_position
                FROM information_schema.columns
                WHERE table_schema = %s AND table_name = %s
                ORDER BY ordinal_position
                """,
                (schema, table_name),
            )
            columns = cur.fetchall()
            # The join also matches on table_name. Otherwise a constraint on
            # another table in the same schema that happens to share the
            # primary key's name would add that table's columns here.
            cur.execute(
                """
                SELECT kcu.column_name
                FROM information_schema.table_constraints tc
                JOIN information_schema.key_column_usage kcu
                  ON tc.constraint_name = kcu.constraint_name
                 AND tc.table_schema = kcu.table_schema
                 AND tc.table_name = kcu.table_name
                WHERE tc.constraint_type = 'PRIMARY KEY'
                  AND tc.table_schema = %s
                  AND tc.table_name = %s
                ORDER BY kcu.ordinal_position
                """,
                (schema, table_name),
            )
            primary_keys = {row["column_name"] for row in cur.fetchall()}
            # Index key columns are collected in key order. Key entries with no
            # matching pg_attribute row (attnum 0) are dropped by the FILTER.
            cur.execute(
                """
                SELECT
                    idx.relname AS index_name,
                    ix.indisunique AS is_unique,
                    ix.indisprimary AS is_primary,
                    pg_get_indexdef(ix.indexrelid) AS definition,
                    COALESCE(
                        array_agg(att.attname ORDER BY keys.ordinality)
                            FILTER (WHERE att.attname IS NOT NULL),
                        ARRAY[]::text[]
                    ) AS columns
                FROM pg_class tbl
                JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
                JOIN pg_index ix ON ix.indrelid = tbl.oid
                JOIN pg_class idx ON idx.oid = ix.indexrelid
                LEFT JOIN LATERAL unnest(ix.indkey) WITH ORDINALITY AS keys(attnum, ordinality) ON TRUE
                LEFT JOIN pg_attribute att
                  ON att.attrelid = tbl.oid
                 AND att.attnum = keys.attnum
                WHERE ns.nspname = %s
                  AND tbl.relname = %s
                GROUP BY idx.relname, ix.indisunique, ix.indisprimary, ix.indexrelid
                ORDER BY idx.relname
                """,
                (schema, table_name),
            )
            index_rows = cur.fetchall()

        if not columns:
            return None

        return {
            "columns": [
                {
                    "column_name": row["column_name"],
                    "data_type": row["data_type"],
                    "udt_name": row["udt_name"],
                    "nullable": row["is_nullable"] == "YES",
                    "default": row["column_default"],
                    "primary_key": row["column_name"] in primary_keys,
                }
                for row in columns
            ],
            "indexes": [
                {
                    "index_name": row["index_name"],
                    "columns": row["columns"],
                    "unique": row["is_unique"],
                    "primary_key": row["is_primary"],
                    "definition": row["definition"],
                }
                for row in index_rows
            ],
        }

    def build_sample_query(self, schema: str, table_name: str, sentinel_limit: int):
        """Build a ``psycopg.sql`` composable ``SELECT * ... LIMIT n``.

        Identifiers and the limit are quoted by psycopg rather than interpolated.

        Raises:
            ConfigurationError: if psycopg is not installed.
        """
        if sql is None:
            raise ConfigurationError("缺少 psycopg 依赖,请先安装项目依赖。")
        return sql.SQL("SELECT * FROM {}.{} LIMIT {}").format(
            sql.Identifier(schema),
            sql.Identifier(table_name),
            sql.Literal(sentinel_limit),
        )

    def build_explain_query(self, sql_text: str, analyze: bool = False) -> str:
        """Wrap *sql_text* in ``EXPLAIN (FORMAT JSON, ANALYZE TRUE|FALSE)``."""
        return f"EXPLAIN (FORMAT JSON, ANALYZE {'TRUE' if analyze else 'FALSE'}) {sql_text}"

    def extract_plan(self, rows):
        """Return the ``QUERY PLAN`` value of the first row, or ``[]`` if there is none."""
        return rows[0].get("QUERY PLAN", []) if rows else []

    def column_names(self, description) -> List[str]:
        """Return column names from a psycopg ``cursor.description``."""
        return [column.name for column in (description or [])]

    def _get_pool(self, connection_id: str, dsn: str) -> "ConnectionPool":
        """Return the cached pool for *connection_id*, creating it on first use.

        New pools are opened immediately with ``min_size=0`` and ``max_size=4``.

        Raises:
            ConfigurationError: if psycopg or psycopg-pool is not installed.
        """
        if ConnectionPool is None or dict_row is None:
            raise ConfigurationError("缺少 psycopg / psycopg-pool 依赖,请先安装项目依赖。")
        pool = self._pools.get(connection_id)
        if pool is None:
            pool = ConnectionPool(
                conninfo=dsn,
                min_size=0,
                max_size=4,
                open=True,
                kwargs={"autocommit": True, "row_factory": dict_row},
            )
            self._pools[connection_id] = pool
        return pool
|
sql_query_mcp/app.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""FastMCP application for stateless SQL queries."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
from mcp.server.fastmcp import FastMCP
|
|
8
|
+
|
|
9
|
+
from .audit import AuditLogger
|
|
10
|
+
from .config import load_config
|
|
11
|
+
from .errors import SqlQueryMCPError
|
|
12
|
+
from .executor import QueryExecutor
|
|
13
|
+
from .introspection import MetadataService
|
|
14
|
+
from .registry import ConnectionRegistry
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def create_app() -> FastMCP:
    """Build the sql-query-mcp FastMCP server with all tools registered.

    Loads configuration once. A single ConnectionRegistry, AuditLogger,
    MetadataService and QueryExecutor are shared by every tool defined below.
    """
    app_config = load_config()
    registry = ConnectionRegistry(app_config)
    audit_logger = AuditLogger(app_config.settings.audit_log_path)
    # Introspection and query execution both receive the same audit logger.
    metadata = MetadataService(registry, app_config.settings, audit_logger)
    executor = QueryExecutor(registry, app_config.settings, audit_logger)

    mcp = FastMCP("sql-query-mcp", json_response=True)

    # NOTE(review): the tool docstrings below are runtime text (presumably
    # FastMCP publishes them as tool descriptions), so edit them deliberately.
    # Every tool except list_connections runs through _run_tool, which
    # re-raises SqlQueryMCPError as ValueError.

    @mcp.tool()
    def list_connections() -> dict:
        """List configured SQL connections by connection_id."""

        return {"connections": registry.list_connections()}

    @mcp.tool()
    def list_schemas(connection_id: str) -> dict:
        """List visible schemas for a PostgreSQL connection."""

        return _run_tool(lambda: metadata.list_schemas(connection_id))

    @mcp.tool()
    def list_databases(connection_id: str) -> dict:
        """List visible databases for a MySQL connection."""

        return _run_tool(lambda: metadata.list_databases(connection_id))

    # schema (PostgreSQL) and database (MySQL) are both optional; resolving
    # which one applies is left to MetadataService.
    @mcp.tool()
    def list_tables(
        connection_id: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
    ) -> dict:
        """List tables and views for a resolved PostgreSQL schema or MySQL database."""

        return _run_tool(lambda: metadata.list_tables(connection_id, schema, database))

    @mcp.tool()
    def describe_table(
        connection_id: str,
        table_name: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
    ) -> dict:
        """Describe columns, keys, and indexes for a table."""

        return _run_tool(lambda: metadata.describe_table(connection_id, table_name, schema, database))

    # The read-only check is expected to live in QueryExecutor / the validator
    # module; it is not visible here.
    @mcp.tool()
    def run_select(connection_id: str, sql: str, limit: Optional[int] = None) -> dict:
        """Run a read-only SELECT or CTE query."""

        return _run_tool(lambda: executor.run_select(connection_id, sql, limit))

    @mcp.tool()
    def explain_query(connection_id: str, sql: str, analyze: bool = False) -> dict:
        """Run EXPLAIN on a read-only SELECT or CTE query."""

        return _run_tool(lambda: executor.explain_query(connection_id, sql, analyze))

    @mcp.tool()
    def get_table_sample(
        connection_id: str,
        table_name: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> dict:
        """Fetch a small sample from a table for schema discovery."""

        return _run_tool(lambda: executor.get_table_sample(connection_id, table_name, schema, database, limit))

    return mcp
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _run_tool(func):
    """Call *func* and return its result, surfacing domain errors as ValueError.

    A ``SqlQueryMCPError`` is re-raised as ``ValueError`` carrying the same
    message, chained to the original exception. Any other exception
    propagates unchanged.
    """
    try:
        outcome = func()
    except SqlQueryMCPError as domain_error:
        raise ValueError(str(domain_error)) from domain_error
    return outcome
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def main() -> None:
    """Entry point: build the MCP server via create_app() and call its run()."""
    create_app().run()


if __name__ == "__main__":
    main()
|
sql_query_mcp/audit.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Audit logging utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Dict, Optional
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AuditLogger:
    """Append audit events to a file as JSON Lines (one JSON object per line)."""

    def __init__(self, log_path: Path):
        # Path() also accepts a str, so callers may pass either.
        self._log_path = Path(log_path)

    def log(
        self,
        *,
        tool: str,
        connection_id: Optional[str],
        success: bool,
        duration_ms: int,
        row_count: Optional[int] = None,
        sql_summary: Optional[str] = None,
        error: Optional[str] = None,
        extra: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Append one audit record to the log file.

        The record starts with a UTC ISO-8601 timestamp, followed by the given
        fields in a fixed key order. A falsy ``extra`` is stored as ``{}``.
        Missing parent directories are created first. Text is written as UTF-8
        without ASCII-escaping.
        """
        stamp = datetime.now(timezone.utc).isoformat()
        payload: Dict[str, Any] = dict(
            timestamp=stamp,
            tool=tool,
            connection_id=connection_id,
            success=success,
            duration_ms=duration_ms,
            row_count=row_count,
            sql_summary=sql_summary,
            error=error,
            extra=extra if extra else {},
        )
        target = self._log_path
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("a", encoding="utf-8") as sink:
            sink.write(f"{json.dumps(payload, ensure_ascii=False)}\n")
|