sql-query-mcp 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
"""sql-query-mcp package."""

__all__ = ["__version__"]

# Keep in sync with the distribution version declared in the packaging metadata
# (this module previously lagged behind at "0.1.0" while 0.1.1 was released).
__version__ = "0.1.1"
@@ -0,0 +1,5 @@
"""Allow running the server with ``python -m sql_query_mcp``."""

from .app import main

if __name__ == "__main__":
    # Delegate straight to the application entry point.
    main()
@@ -0,0 +1,15 @@
"""Engine adapters for sql-query-mcp.

The adapter classes are exported lazily through a module-level ``__getattr__``
(PEP 562), so importing this package does not import the engine submodules
(and the database drivers they try to load) until an adapter is first accessed.
"""

__all__ = ["MySQLAdapter", "PostgresAdapter"]


def __getattr__(name: str):
    """Resolve ``MySQLAdapter`` / ``PostgresAdapter`` on first access.

    The resolved class is stored in the module globals, so subsequent lookups
    find it directly and no longer go through this hook.

    Raises:
        AttributeError: for any other name, with the standard module message.
    """
    if name == "MySQLAdapter":
        from .mysql import MySQLAdapter as adapter_cls
    elif name == "PostgresAdapter":
        from .postgres import PostgresAdapter as adapter_cls
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    # Cache the class so the import machinery is only consulted once per name.
    globals()[name] = adapter_cls
    return adapter_cls


def __dir__():
    """List module attributes including the lazily exported adapter names."""
    return sorted(set(globals()) | set(__all__))
@@ -0,0 +1,171 @@
1
+ """MySQL adapter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from contextlib import contextmanager
7
+ from typing import Iterator, List
8
+ from urllib.parse import parse_qs, unquote, urlparse
9
+
10
+ try:
11
+ import pymysql
12
+ from pymysql.cursors import DictCursor
13
+ except ImportError: # pragma: no cover - runtime dependency
14
+ pymysql = None
15
+ DictCursor = None
16
+
17
+ from ..errors import ConfigurationError, SecurityError
18
+
19
+
class MySQLAdapter:
    """MySQL engine adapter built on PyMySQL.

    A new connection is opened for every ``connection()`` call and closed when
    the context exits; there is no pooling, so ``close()`` is a no-op.

    Every ``information_schema`` column is selected with an explicit lowercase
    alias. MySQL 8.0 reports unaliased information_schema columns under their
    uppercase names (e.g. ``COLUMN_NAME``). Without aliases, the lowercase dict
    lookups below would raise KeyError, and callers of ``list_tables`` would get
    engine-version-dependent keys.
    """

    engine = "mysql"

    @contextmanager
    def connection(self, connection_id: str, dsn: str) -> Iterator[object]:
        """Yield an autocommit PyMySQL connection (DictCursor rows) for *dsn*.

        ``connection_id`` is accepted for interface parity with the pooled
        PostgreSQL adapter; it is not used here.

        Raises:
            ConfigurationError: if PyMySQL is missing or the DSN is invalid.
        """
        if pymysql is None or DictCursor is None:
            raise ConfigurationError("缺少 PyMySQL 依赖,请先安装项目依赖。")

        conn = pymysql.connect(
            autocommit=True,
            cursorclass=DictCursor,
            **self._parse_dsn(dsn),
        )
        try:
            yield conn
        finally:
            conn.close()

    def close(self) -> None:
        """No-op: connections are closed per use, nothing is held here."""
        return None

    def set_statement_timeout(self, conn: object, timeout_ms: int) -> None:
        """Set the session ``max_execution_time`` to *timeout_ms* on *conn*."""
        with conn.cursor() as cur:
            cur.execute("SET SESSION max_execution_time = %s", (int(timeout_ms),))

    def list_databases(self, conn: object) -> List[str]:
        """Return user database names, excluding MySQL's system schemas."""
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT schema_name AS database_name
                FROM information_schema.schemata
                WHERE schema_name NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')
                ORDER BY schema_name
                """
            )
            return [row["database_name"] for row in cur.fetchall()]

    def list_tables(self, conn: object, database: str):
        """Return rows with ``database_name``, ``table_name`` and ``table_type``."""
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT table_schema AS database_name,
                       table_name AS table_name,
                       table_type AS table_type
                FROM information_schema.tables
                WHERE table_schema = %s
                ORDER BY table_name
                """,
                (database,),
            )
            return cur.fetchall()

    def describe_table(self, conn: object, database: str, table_name: str):
        """Describe the columns and indexes of ``database.table_name``.

        Returns:
            ``None`` when the table has no visible columns (e.g. it does not
            exist); otherwise a dict with ``columns`` (in ordinal order) and
            ``indexes`` (see ``_normalize_indexes``).
        """
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT column_name AS column_name,
                       column_type AS column_type,
                       is_nullable AS is_nullable,
                       column_default AS column_default,
                       extra AS extra,
                       column_key AS column_key,
                       ordinal_position AS ordinal_position
                FROM information_schema.columns
                WHERE table_schema = %s AND table_name = %s
                ORDER BY ordinal_position
                """,
                (database, table_name),
            )
            columns = cur.fetchall()
            cur.execute(
                """
                SELECT index_name AS index_name,
                       non_unique AS non_unique,
                       seq_in_index AS seq_in_index,
                       column_name AS column_name
                FROM information_schema.statistics
                WHERE table_schema = %s AND table_name = %s
                ORDER BY index_name, seq_in_index
                """,
                (database, table_name),
            )
            index_rows = cur.fetchall()

        if not columns:
            return None

        return {
            "columns": [
                {
                    "column_name": row["column_name"],
                    "data_type": row["column_type"],
                    # MySQL has no separate UDT name; kept for parity with PostgreSQL.
                    "udt_name": None,
                    "nullable": row["is_nullable"] == "YES",
                    "default": row["column_default"],
                    "primary_key": row["column_key"] == "PRI",
                    "extra": row["extra"],
                }
                for row in columns
            ],
            "indexes": self._normalize_indexes(index_rows),
        }

    def build_sample_query(self, database: str, table_name: str, sentinel_limit: int) -> str:
        """Build ``SELECT * ... LIMIT n`` with backtick-quoted identifiers."""
        return (
            f"SELECT * FROM {self._quote_identifier(database)}."
            f"{self._quote_identifier(table_name)} LIMIT {int(sentinel_limit)}"
        )

    def build_explain_query(self, sql_text: str, analyze: bool = False) -> str:
        """Wrap *sql_text* in ``EXPLAIN FORMAT=JSON``.

        Raises:
            SecurityError: if ``analyze`` is requested (not supported for MySQL).
        """
        if analyze:
            raise SecurityError("MySQL 首版不支持 analyze=True。")
        return f"EXPLAIN FORMAT=JSON {sql_text}"

    def extract_plan(self, rows):
        """Return the parsed JSON plan from the ``EXPLAIN`` column of the first row.

        Falls back to the raw string if it is not valid JSON, and to ``[]``
        when there are no rows.
        """
        if not rows:
            return []
        plan = rows[0].get("EXPLAIN", [])
        if isinstance(plan, str):
            try:
                return json.loads(plan)
            except json.JSONDecodeError:
                return plan
        return plan

    def column_names(self, description) -> List[str]:
        """Extract column names from a DB-API ``cursor.description``."""
        return [column[0] for column in (description or [])]

    def _parse_dsn(self, dsn: str) -> dict:
        """Translate a ``mysql://`` / ``mysql+pymysql://`` URL into connect kwargs.

        User, password and database name are percent-decoded. Host defaults to
        ``localhost``, port to 3306 and charset (``?charset=``) to ``utf8mb4``.
        Keys whose value is absent are dropped.

        Raises:
            ConfigurationError: for any other URL scheme.
        """
        parsed = urlparse(dsn)
        if parsed.scheme not in {"mysql", "mysql+pymysql"}:
            raise ConfigurationError(f"MySQL DSN 必须使用 mysql:// 或 mysql+pymysql://,当前为 {parsed.scheme}")

        # Last occurrence wins when a query parameter is repeated.
        query_params = {key: values[-1] for key, values in parse_qs(parsed.query).items()}
        connect_args = {
            "host": parsed.hostname or "localhost",
            "user": unquote(parsed.username) if parsed.username else None,
            "password": unquote(parsed.password) if parsed.password else None,
            "port": parsed.port or 3306,
            # Decode like user/password so e.g. "my%20db" becomes "my db".
            "database": unquote(parsed.path.lstrip("/")) or None,
            "charset": query_params.get("charset", "utf8mb4"),
        }
        return {key: value for key, value in connect_args.items() if value is not None}

    def _quote_identifier(self, value: str) -> str:
        """Backtick-quote an identifier, doubling any embedded backticks."""
        return "`" + value.replace("`", "``") + "`"

    def _normalize_indexes(self, rows: List[dict]) -> List[dict]:
        """Group per-column statistics rows into one entry per index.

        Column order follows the input order (the query sorts by
        ``seq_in_index``); the resulting list is sorted by index name.
        """
        grouped = {}
        for row in rows:
            index_name = row["index_name"]
            item = grouped.setdefault(
                index_name,
                {
                    "index_name": index_name,
                    "columns": [],
                    "unique": row["non_unique"] == 0,
                    "primary_key": index_name == "PRIMARY",
                    "definition": None,
                },
            )
            item["columns"].append(row["column_name"])
        return [grouped[name] for name in sorted(grouped)]
@@ -0,0 +1,180 @@
1
+ """PostgreSQL adapter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from contextlib import contextmanager
6
+ from typing import Iterator, List
7
+
8
+ try:
9
+ from psycopg import sql
10
+ from psycopg.rows import dict_row
11
+ from psycopg_pool import ConnectionPool
12
+ except ImportError: # pragma: no cover - runtime dependency
13
+ sql = None
14
+ dict_row = None
15
+ ConnectionPool = None
16
+
17
+ from ..errors import ConfigurationError
18
+
19
+
class PostgresAdapter:
    """PostgreSQL engine adapter built on psycopg 3 and psycopg_pool.

    One ``ConnectionPool`` is created lazily per ``connection_id`` and reused
    for later calls with that id. Connections use autocommit and return rows
    as dicts (``dict_row``).
    """

    engine = "postgres"

    def __init__(self) -> None:
        # connection_id -> ConnectionPool, populated lazily by _get_pool().
        self._pools = {}

    @contextmanager
    def connection(self, connection_id: str, dsn: str) -> Iterator[object]:
        """Borrow a pooled connection for *connection_id* for the duration of the block.

        Note: *dsn* is only used when the pool for this id is first created.
        """
        pool = self._get_pool(connection_id, dsn)
        with pool.connection() as conn:
            yield conn

    def close(self) -> None:
        """Close every pool and forget it.

        Each pool is removed from the registry before it is closed. A later
        ``connection()`` therefore builds a fresh pool instead of reusing a
        closed one. If a ``close()`` raises, the remaining pools stay
        registered so a retry can still close them.
        """
        while self._pools:
            _, pool = self._pools.popitem()
            pool.close()

    def set_statement_timeout(self, conn: object, timeout_ms: int) -> None:
        """Set ``statement_timeout`` (ms) with ``is_local=false``, i.e. for the session."""
        with conn.cursor() as cur:
            cur.execute("SELECT set_config('statement_timeout', %s, false)", (str(timeout_ms),))

    def list_schemas(self, conn: object) -> List[str]:
        """Return schema names, excluding ``information_schema`` and ``pg_*`` system schemas."""
        with conn.cursor() as cur:
            # Anchored regex rather than LIKE 'pg_%': in LIKE, "_" matches any
            # character, which would also hide user schemas such as "pgadmin".
            cur.execute(
                """
                SELECT schema_name
                FROM information_schema.schemata
                WHERE schema_name <> 'information_schema'
                  AND schema_name !~ '^pg_'
                ORDER BY schema_name
                """
            )
            return [row["schema_name"] for row in cur.fetchall()]

    def list_tables(self, conn: object, schema: str):
        """Return rows with ``schema``, ``table_name`` and ``table_type`` for *schema*."""
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT table_schema AS schema, table_name, table_type
                FROM information_schema.tables
                WHERE table_schema = %s
                ORDER BY table_name
                """,
                (schema,),
            )
            return cur.fetchall()

    def describe_table(self, conn: object, schema: str, table_name: str):
        """Describe columns, primary key membership and indexes of ``schema.table_name``.

        Returns:
            ``None`` when the table has no visible columns; otherwise a dict with
            ``columns`` (ordinal order) and ``indexes`` (sorted by index name).
        """
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT column_name, data_type, udt_name, is_nullable, column_default, ordinal_position
                FROM information_schema.columns
                WHERE table_schema = %s AND table_name = %s
                ORDER BY ordinal_position
                """,
                (schema, table_name),
            )
            columns = cur.fetchall()
            cur.execute(
                """
                SELECT kcu.column_name
                FROM information_schema.table_constraints tc
                JOIN information_schema.key_column_usage kcu
                  ON tc.constraint_name = kcu.constraint_name
                 AND tc.table_schema = kcu.table_schema
                 AND tc.table_name = kcu.table_name
                WHERE tc.constraint_type = 'PRIMARY KEY'
                  AND tc.table_schema = %s
                  AND tc.table_name = %s
                ORDER BY kcu.ordinal_position
                """,
                (schema, table_name),
            )
            primary_keys = {row["column_name"] for row in cur.fetchall()}
            # Expression index keys have attnum 0, match no pg_attribute row and
            # are filtered out of the column array (which may then be empty).
            cur.execute(
                """
                SELECT
                    idx.relname AS index_name,
                    ix.indisunique AS is_unique,
                    ix.indisprimary AS is_primary,
                    pg_get_indexdef(ix.indexrelid) AS definition,
                    COALESCE(
                        array_agg(att.attname ORDER BY keys.ordinality)
                            FILTER (WHERE att.attname IS NOT NULL),
                        ARRAY[]::text[]
                    ) AS columns
                FROM pg_class tbl
                JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
                JOIN pg_index ix ON ix.indrelid = tbl.oid
                JOIN pg_class idx ON idx.oid = ix.indexrelid
                LEFT JOIN LATERAL unnest(ix.indkey) WITH ORDINALITY AS keys(attnum, ordinality) ON TRUE
                LEFT JOIN pg_attribute att
                  ON att.attrelid = tbl.oid
                 AND att.attnum = keys.attnum
                WHERE ns.nspname = %s
                  AND tbl.relname = %s
                GROUP BY idx.relname, ix.indisunique, ix.indisprimary, ix.indexrelid
                ORDER BY idx.relname
                """,
                (schema, table_name),
            )
            index_rows = cur.fetchall()

        if not columns:
            return None

        return {
            "columns": [
                {
                    "column_name": row["column_name"],
                    "data_type": row["data_type"],
                    "udt_name": row["udt_name"],
                    "nullable": row["is_nullable"] == "YES",
                    "default": row["column_default"],
                    "primary_key": row["column_name"] in primary_keys,
                }
                for row in columns
            ],
            "indexes": [
                {
                    "index_name": row["index_name"],
                    "columns": row["columns"],
                    "unique": row["is_unique"],
                    "primary_key": row["is_primary"],
                    "definition": row["definition"],
                }
                for row in index_rows
            ],
        }

    def build_sample_query(self, schema: str, table_name: str, sentinel_limit: int):
        """Build a composed ``SELECT * FROM schema.table LIMIT n`` query.

        Identifiers are quoted by ``psycopg.sql.Identifier``. The limit is
        coerced to ``int`` (as the MySQL adapter does), so it always renders as
        a numeric literal.

        Raises:
            ConfigurationError: if psycopg is not installed.
        """
        if sql is None:
            raise ConfigurationError("缺少 psycopg 依赖,请先安装项目依赖。")
        return sql.SQL("SELECT * FROM {}.{} LIMIT {}").format(
            sql.Identifier(schema),
            sql.Identifier(table_name),
            sql.Literal(int(sentinel_limit)),
        )

    def build_explain_query(self, sql_text: str, analyze: bool = False) -> str:
        """Wrap *sql_text* in ``EXPLAIN (FORMAT JSON, ANALYZE TRUE|FALSE)``."""
        return f"EXPLAIN (FORMAT JSON, ANALYZE {'TRUE' if analyze else 'FALSE'}) {sql_text}"

    def extract_plan(self, rows):
        """Return the ``QUERY PLAN`` value of the first row, or ``[]`` if there are no rows."""
        return rows[0].get("QUERY PLAN", []) if rows else []

    def column_names(self, description) -> List[str]:
        """Extract column names from a psycopg ``cursor.description``."""
        return [column.name for column in (description or [])]

    def _get_pool(self, connection_id: str, dsn: str) -> "ConnectionPool":
        """Return the pool for *connection_id*, creating it on first use.

        Raises:
            ConfigurationError: if psycopg / psycopg_pool are not installed.
        """
        if ConnectionPool is None or dict_row is None:
            raise ConfigurationError("缺少 psycopg / psycopg-pool 依赖,请先安装项目依赖。")
        pool = self._pools.get(connection_id)
        if pool is None:
            pool = ConnectionPool(
                conninfo=dsn,
                min_size=0,
                max_size=4,
                open=True,
                kwargs={"autocommit": True, "row_factory": dict_row},
            )
            self._pools[connection_id] = pool
        return pool
sql_query_mcp/app.py ADDED
@@ -0,0 +1,105 @@
1
+ """FastMCP application for stateless SQL queries."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Optional
6
+
7
+ from mcp.server.fastmcp import FastMCP
8
+
9
+ from .audit import AuditLogger
10
+ from .config import load_config
11
+ from .errors import SqlQueryMCPError
12
+ from .executor import QueryExecutor
13
+ from .introspection import MetadataService
14
+ from .registry import ConnectionRegistry
15
+
16
+
def create_app() -> FastMCP:
    """Construct the sql-query-mcp FastMCP server with all SQL tools registered.

    Configuration is loaded once through ``load_config()``. The connection
    registry, audit logger, metadata service and query executor built from it
    are shared by every tool closure below. Tool bodies go through
    ``_run_tool``, which turns ``SqlQueryMCPError`` into ``ValueError``.
    """
    config = load_config()
    connections = ConnectionRegistry(config)
    auditor = AuditLogger(config.settings.audit_log_path)
    metadata_service = MetadataService(connections, config.settings, auditor)
    query_executor = QueryExecutor(connections, config.settings, auditor)

    server = FastMCP("sql-query-mcp", json_response=True)

    # NOTE(review): FastMCP presumably derives each tool's name, parameters and
    # description from the function signature and docstring below, so those
    # are part of the client-facing contract and must stay stable.

    @server.tool()
    def list_connections() -> dict:
        """List configured SQL connections by connection_id."""
        return {"connections": connections.list_connections()}

    @server.tool()
    def list_schemas(connection_id: str) -> dict:
        """List visible schemas for a PostgreSQL connection."""
        return _run_tool(lambda: metadata_service.list_schemas(connection_id))

    @server.tool()
    def list_databases(connection_id: str) -> dict:
        """List visible databases for a MySQL connection."""
        return _run_tool(lambda: metadata_service.list_databases(connection_id))

    @server.tool()
    def list_tables(
        connection_id: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
    ) -> dict:
        """List tables and views for a resolved PostgreSQL schema or MySQL database."""
        return _run_tool(
            lambda: metadata_service.list_tables(connection_id, schema, database)
        )

    @server.tool()
    def describe_table(
        connection_id: str,
        table_name: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
    ) -> dict:
        """Describe columns, keys, and indexes for a table."""
        return _run_tool(
            lambda: metadata_service.describe_table(connection_id, table_name, schema, database)
        )

    @server.tool()
    def run_select(connection_id: str, sql: str, limit: Optional[int] = None) -> dict:
        """Run a read-only SELECT or CTE query."""
        return _run_tool(lambda: query_executor.run_select(connection_id, sql, limit))

    @server.tool()
    def explain_query(connection_id: str, sql: str, analyze: bool = False) -> dict:
        """Run EXPLAIN on a read-only SELECT or CTE query."""
        return _run_tool(lambda: query_executor.explain_query(connection_id, sql, analyze))

    @server.tool()
    def get_table_sample(
        connection_id: str,
        table_name: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> dict:
        """Fetch a small sample from a table for schema discovery."""
        return _run_tool(
            lambda: query_executor.get_table_sample(connection_id, table_name, schema, database, limit)
        )

    return server
90
+
91
+
92
+ def _run_tool(func):
93
+ try:
94
+ return func()
95
+ except SqlQueryMCPError as exc:
96
+ raise ValueError(str(exc)) from exc
97
+
98
+
def main() -> None:
    """Build the sql-query-mcp server via ``create_app()`` and run it."""
    create_app().run()


if __name__ == "__main__":
    main()
sql_query_mcp/audit.py ADDED
@@ -0,0 +1,42 @@
1
+ """Audit logging utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from datetime import datetime, timezone
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional
9
+
10
+
class AuditLogger:
    """Append audit records to a JSON Lines file, one JSON object per line."""

    def __init__(self, log_path: Path):
        # Normalised once; accepts any value Path() accepts (e.g. str or Path).
        self._log_path = Path(log_path)

    def log(
        self,
        *,
        tool: str,
        connection_id: Optional[str],
        success: bool,
        duration_ms: int,
        row_count: Optional[int] = None,
        sql_summary: Optional[str] = None,
        error: Optional[str] = None,
        extra: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Append a single audit record for one tool invocation.

        The record is stamped with the current UTC time (ISO-8601). The log
        file's parent directory is created on demand, and the record is written
        as one UTF-8 line with non-ASCII characters kept unescaped.

        Args:
            tool: Name of the tool that ran.
            connection_id: Target connection id, if any.
            success: Whether the call succeeded.
            duration_ms: Elapsed time in milliseconds.
            row_count: Rows returned, if applicable.
            sql_summary: Short description of the SQL involved, if any.
            error: Error message for failed calls.
            extra: Additional fields; a falsy value is recorded as ``{}``.
        """
        entry: Dict[str, Any] = dict(
            timestamp=datetime.now(timezone.utc).isoformat(),
            tool=tool,
            connection_id=connection_id,
            success=success,
            duration_ms=duration_ms,
            row_count=row_count,
            sql_summary=sql_summary,
            error=error,
            extra=extra if extra else {},
        )
        target = self._log_path
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("a", encoding="utf-8") as stream:
            stream.write(f"{json.dumps(entry, ensure_ascii=False)}\n")