sql-query-mcp 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,255 @@
1
+ """Configuration loading for sql-query-mcp."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import re
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Dict, Iterable, List, Optional
11
+
12
+ from .errors import ConfigurationError
13
+
14
# Directory two levels above this file, i.e. the parent of the package directory.
PACKAGE_ROOT = Path(__file__).resolve().parent.parent
# Environment variable consulted by resolve_config_path() when no explicit
# config path is passed in.
DEFAULT_CONFIG_ENV = "SQL_QUERY_MCP_CONFIG"
# Config file used when neither an argument nor the environment variable
# supplies a path.
DEFAULT_CONFIG_PATH = PACKAGE_ROOT / "config" / "connections.json"
# Default value of ServerSettings.audit_log_path.
DEFAULT_AUDIT_LOG_PATH = PACKAGE_ROOT / "logs" / "audit.jsonl"
# A connection_id is one lowercase-alphanumeric segment followed by three or
# more "_segment" parts, so at least four segments in total.
CONNECTION_ID_RE = re.compile(r"^[a-z0-9]+(?:_[a-z0-9]+){3,}$")
# Values accepted for a connection's "engine" field.
SUPPORTED_ENGINES = {"postgres", "mysql"}
20
+
21
+
22
@dataclass(frozen=True)
class ServerSettings:
    """Immutable server-wide settings: row limits, statement timeout, audit log path."""

    # Default row limit; _parse_settings() rejects values that are not greater than 0.
    default_limit: int = 200
    # Upper row limit; _parse_settings() rejects values smaller than default_limit.
    max_limit: int = 1000
    # Statement timeout in milliseconds; None means no timeout is configured.
    statement_timeout_ms: Optional[int] = None
    # Location of the audit log file (presumably JSON Lines, going by the default name).
    audit_log_path: Path = DEFAULT_AUDIT_LOG_PATH
28
+
29
+
30
@dataclass(frozen=True)
class ConnectionConfig:
    """Immutable description of a single configured database connection."""

    connection_id: str
    engine: str
    env: str
    tenant: str
    role: str
    dsn_env: str
    enabled: bool = True
    label: Optional[str] = None
    description: Optional[str] = None
    default_schema: Optional[str] = None
    default_database: Optional[str] = None

    @property
    def summary(self) -> Dict[str, object]:
        """Return the connection's descriptive fields as a plain dict.

        ``dsn_env`` is not included. ``label`` falls back to
        ``connection_id`` when no (or an empty) label is set.
        """
        display_label = self.label if self.label else self.connection_id
        return dict(
            connection_id=self.connection_id,
            engine=self.engine,
            label=display_label,
            env=self.env,
            tenant=self.tenant,
            role=self.role,
            enabled=self.enabled,
            default_schema=self.default_schema,
            default_database=self.default_database,
            description=self.description,
        )
58
+
59
+
60
@dataclass(frozen=True)
class AppConfig:
    """Parsed configuration: the server settings plus every declared connection."""

    settings: ServerSettings
    connections: List[ConnectionConfig]

    @property
    def connection_map(self) -> Dict[str, ConnectionConfig]:
        """Index the connections by ``connection_id`` (a later duplicate wins)."""
        by_id = {}
        for connection in self.connections:
            by_id[connection.connection_id] = connection
        return by_id

    def enabled_connections(self) -> Iterable[ConnectionConfig]:
        """Return a lazy iterable over the connections whose ``enabled`` is truthy."""
        return (
            connection for connection in self.connections if connection.enabled
        )
71
+
72
+
73
def resolve_config_path(config_path: Optional[str] = None) -> Path:
    """Work out which configuration file to load.

    Precedence: the explicit ``config_path`` argument, then the environment
    variable named by ``DEFAULT_CONFIG_ENV``, then ``DEFAULT_CONFIG_PATH``.
    A path taken from the argument or the environment has ``~`` expanded and
    is resolved to an absolute path; the default is returned as is.
    """
    candidate = config_path
    if not candidate:
        candidate = os.environ.get(DEFAULT_CONFIG_ENV)
    if not candidate:
        return DEFAULT_CONFIG_PATH
    return Path(candidate).expanduser().resolve()
76
+
77
+
78
def load_config(config_path: Optional[str] = None) -> AppConfig:
    """Load and validate the application configuration file.

    Args:
        config_path: Optional explicit path; see resolve_config_path() for
            how the effective path is chosen.

    Returns:
        The parsed AppConfig. When the file does not exist, an AppConfig with
        default ServerSettings and no connections.

    Raises:
        ConfigurationError: If the file is not valid JSON, if its top-level
            value or its ``settings`` value is not a JSON object, or if any
            setting or connection entry fails validation.
    """
    path = resolve_config_path(config_path)
    if not path.exists():
        # A missing file is not an error: run with defaults and no connections.
        return AppConfig(settings=ServerSettings(), connections=[])

    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        raise ConfigurationError(f"配置文件不是有效 JSON: {path}") from exc

    # Valid JSON may still be a list, string, number or null; calling .get()
    # on those would surface as a raw AttributeError instead of a
    # configuration problem.
    if not isinstance(payload, dict):
        raise ConfigurationError(f"配置文件顶层必须是对象: {path}")

    raw_settings = payload.get("settings", {})
    if not isinstance(raw_settings, dict):
        raise ConfigurationError("settings 必须是对象")

    settings = _parse_settings(raw_settings, path)
    connections = _parse_connections(payload.get("connections", []))
    return AppConfig(settings=settings, connections=connections)
91
+
92
+
93
def _parse_settings(data: Dict[str, object], path: Path) -> ServerSettings:
    """Build ServerSettings from the ``settings`` object of the config file.

    Args:
        data: The raw ``settings`` mapping.
        path: Location of the config file; a relative ``audit_log_path`` is
            resolved against this file's directory.

    Returns:
        The validated ServerSettings.

    Raises:
        ConfigurationError: If a numeric field is not an integer, if
            ``statement_timeout_ms`` is not positive, if ``default_limit`` is
            not greater than 0, or if ``max_limit`` is below ``default_limit``.
    """
    limit_default = _required_int(data.get("default_limit", 200), "default_limit")
    limit_ceiling = _required_int(data.get("max_limit", 1000), "max_limit")
    timeout_ms = _optional_positive_int(
        data.get("statement_timeout_ms"), "statement_timeout_ms"
    )

    # The fallback is expressed relative to the package root so that it goes
    # through the same "relative to the config file" resolution as a
    # user-supplied relative path.
    fallback_log = DEFAULT_AUDIT_LOG_PATH.relative_to(PACKAGE_ROOT)
    log_path = Path(str(data.get("audit_log_path", fallback_log)))
    if not log_path.is_absolute():
        log_path = (path.parent / log_path).resolve()

    if limit_default < 1:
        raise ConfigurationError("default_limit 必须大于 0")
    if limit_default > limit_ceiling:
        raise ConfigurationError("max_limit 不能小于 default_limit")

    return ServerSettings(
        default_limit=limit_default,
        max_limit=limit_ceiling,
        statement_timeout_ms=timeout_ms,
        audit_log_path=log_path,
    )
116
+
117
+
118
def _parse_connections(items: object) -> List[ConnectionConfig]:
    """Validate the raw ``connections`` array and build ConnectionConfig objects.

    Args:
        items: The value of the ``connections`` key from the config payload.

    Returns:
        One ConnectionConfig per entry, in the order they appear.

    Raises:
        ConfigurationError: If ``items`` is not a list, an entry is not an
            object, ``label``/``enabled`` is missing or malformed, the
            ``connection_id`` does not match CONNECTION_ID_RE or repeats, the
            engine is unsupported, ``env``/``tenant``/``role``/``dsn_env`` is
            empty, a legacy namespace field is present, or a namespace field
            is configured for the wrong engine.
    """
    if not isinstance(items, list):
        raise ConfigurationError("connections 必须是数组")

    result: List[ConnectionConfig] = []
    seen = set()
    for item in items:
        if not isinstance(item, dict):
            raise ConfigurationError("connections 数组中的每一项都必须是对象")

        connection_id = str(item.get("connection_id", "")).strip()
        engine = str(item.get("engine", "")).strip()
        # The id is not validated yet, so fall back to a generic name in
        # error messages when it is empty.
        label = _required_string(item, "label", connection_id or "connection")
        enabled = _required_bool(item, "enabled", connection_id or "connection")
        if not CONNECTION_ID_RE.match(connection_id):
            raise ConfigurationError(
                "connection_id 必须符合 <system>_<env>_<tenant>_<role> 风格,且只包含小写字母、数字、下划线"
            )
        if engine not in SUPPORTED_ENGINES:
            raise ConfigurationError(
                f"{connection_id} 缺少合法 engine,必须是 postgres 或 mysql"
            )
        if connection_id in seen:
            raise ConfigurationError(f"重复的 connection_id: {connection_id}")
        seen.add(connection_id)

        dsn_env = str(item.get("dsn_env", "")).strip()
        env = str(item.get("env", "")).strip()
        tenant = str(item.get("tenant", "")).strip()
        role = str(item.get("role", "")).strip()
        if not all((dsn_env, env, tenant, role)):
            raise ConfigurationError(
                f"{connection_id} 缺少必要字段,必须提供 env / tenant / role / dsn_env"
            )

        _reject_legacy_namespace_fields(item, connection_id)
        default_schema = _optional_string(item.get("default_schema"))
        default_database = _optional_string(item.get("default_database"))

        # Presence of the key is what is rejected, even if its value is null.
        if engine == "postgres" and "default_database" in item:
            raise ConfigurationError(
                f"{connection_id} 是 PostgreSQL 连接,不能配置 default_database"
            )
        if engine == "mysql" and "default_schema" in item:
            raise ConfigurationError(
                f"{connection_id} 是 MySQL 连接,不能配置 default_schema"
            )

        # Falsy descriptions stay None as before; a whitespace-only
        # description now also collapses to None instead of "", matching
        # how the other optional strings are normalised.
        raw_description = item.get("description")
        description = _optional_string(raw_description) if raw_description else None

        result.append(
            ConnectionConfig(
                connection_id=connection_id,
                engine=engine,
                label=label,
                description=description,
                env=env,
                tenant=tenant,
                role=role,
                dsn_env=dsn_env,
                enabled=enabled,
                default_schema=default_schema,
                default_database=default_database,
            )
        )

    return result
187
+
188
+
189
def _reject_legacy_namespace_fields(
    item: Dict[str, object], connection_id: str
) -> None:
    """Refuse connection entries that still use retired namespace keys.

    Raises:
        ConfigurationError: If ``default_namespace`` or ``default_schemas`` is
            present in ``item``; ``default_namespace`` is reported first.
    """
    legacy_fields = (
        (
            "default_namespace",
            f"{connection_id} 仍在使用 default_namespace,请改为 default_schema 或 default_database",
        ),
        (
            "default_schemas",
            f"{connection_id} 仍在使用 default_schemas,请收敛为单值字段 default_schema",
        ),
    )
    for legacy_key, message in legacy_fields:
        if legacy_key in item:
            raise ConfigurationError(message)
200
+
201
+
202
def _optional_string(value: object) -> Optional[str]:
    """Convert ``value`` to a stripped string, or None when nothing is left.

    ``None`` and values whose string form is empty or whitespace-only both
    yield ``None``; anything else is passed through ``str()`` and stripped.
    """
    stripped = "" if value is None else str(value).strip()
    return stripped if stripped else None
207
+
208
+
209
def _optional_positive_int(value: Any, field_name: str) -> Optional[int]:
    """Parse an optional integer setting that must be greater than zero.

    Returns:
        ``None`` when ``value`` is ``None``, otherwise the parsed integer.

    Raises:
        ConfigurationError: If the value is not an integer or is not positive.
    """
    if value is not None:
        number = _required_int(value, field_name)
        if number > 0:
            return number
        raise ConfigurationError(f"{field_name} 必须是大于 0 的整数")
    return None
216
+
217
+
218
def _required_int(value: Any, field_name: str) -> int:
    """Coerce a configuration value to ``int`` or reject it.

    Accepted inputs: real integers, floats with no fractional part, strings
    that ``int()`` can parse, and other objects exposing ``__int__``.
    Booleans are rejected even though ``bool`` is an ``int`` subclass.

    Args:
        value: The raw value from the configuration.
        field_name: Setting name, used in the error message.

    Returns:
        The integer value.

    Raises:
        ConfigurationError: If ``value`` cannot be interpreted as an integer.
    """
    if isinstance(value, bool):
        raise ConfigurationError(f"{field_name} 必须是整数")
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        if value.is_integer():
            return int(value)
        raise ConfigurationError(f"{field_name} 必须是整数")
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError as exc:
            raise ConfigurationError(f"{field_name} 必须是整数") from exc
    try:
        return value.__int__()
    except (AttributeError, TypeError, ValueError) as exc:
        # None, lists and dicts (e.g. JSON null, [] or {}) have no __int__;
        # without AttributeError here they would escape as a raw
        # AttributeError instead of a configuration error.
        raise ConfigurationError(f"{field_name} 必须是整数") from exc
236
+
237
+
238
def _required_string(
    item: Dict[str, object], field_name: str, connection_id: str
) -> str:
    """Fetch a mandatory, non-blank string field from a connection entry.

    Raises:
        ConfigurationError: If the field is absent, null, or blank once
            converted to a string and stripped.
    """
    text = _optional_string(item.get(field_name))
    if text is not None:
        return text
    raise ConfigurationError(f"{connection_id} 缺少必要字段 {field_name}")
245
+
246
+
247
def _required_bool(
    item: Dict[str, object], field_name: str, connection_id: str
) -> bool:
    """Fetch a mandatory boolean field from a connection entry.

    Only genuine ``bool`` values are accepted; truthy or falsy values of any
    other type are rejected.

    Raises:
        ConfigurationError: If the field is absent or is not a ``bool``.
    """
    if field_name in item:
        flag = item[field_name]
        if isinstance(flag, bool):
            return flag
        raise ConfigurationError(f"{connection_id} 的 {field_name} 必须是布尔值")
    raise ConfigurationError(f"{connection_id} 缺少必要字段 {field_name}")
@@ -0,0 +1,34 @@
1
+ """Error types for sql-query-mcp."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+
7
+
8
class SqlQueryMCPError(Exception):
    """Root of the package's exception hierarchy.

    Catching this type catches every error class defined in this module.
    """


class ConfigurationError(SqlQueryMCPError):
    """The local configuration is missing required data or is malformed."""


class ConnectionNotFoundError(SqlQueryMCPError):
    """No connection is registered under the requested ``connection_id``."""


class SecurityError(SqlQueryMCPError):
    """A query was rejected by SQL validation."""


class QueryExecutionError(SqlQueryMCPError):
    """The database execution layer failed while handling a request."""
26
+
27
+
28
# Matches the "user:password@" portion that follows "://" in a URL-style DSN.
# Group 1 is the user name (may be empty, as in "scheme://:secret@host");
# group 2 is the password, which runs up to the next "@" or whitespace.
# "/" is allowed inside the password so that an unencoded slash cannot stop
# the match and leave the secret in the message.
_DSN_CREDENTIALS_RE = re.compile(r"://([^:@/\s]*):([^@\s]+)@")


def sanitize_error_message(message: str) -> str:
    """Mask DSN credentials before surfacing an error to the model.

    Every ``://user:password@`` occurrence in ``message`` has its password
    replaced by ``***``; the user name and the rest of the text are kept.
    The pattern errs on the side of masking: a credential-free URL such as
    ``scheme://host:port/path@x`` is also rewritten.

    Args:
        message: Raw error text that may embed a connection URL.

    Returns:
        The message with passwords masked.
    """

    return _DSN_CREDENTIALS_RE.sub(r"://\1:***@", message)
@@ -0,0 +1,243 @@
1
+ """Query execution helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from typing import Any, Dict, Optional
7
+
8
+ from .audit import AuditLogger
9
+ from .config import ServerSettings
10
+ from .errors import QueryExecutionError, sanitize_error_message
11
+ from .namespace import resolve_namespace
12
+ from .registry import ConnectionRegistry
13
+ from .validator import (
14
+ build_limited_query,
15
+ clamp_limit,
16
+ summarize_sql,
17
+ validate_select_sql,
18
+ )
19
+
20
+
21
class QueryExecutor:
    """Run validated read-only statements on registered connections.

    Every call, successful or not, is recorded through the audit logger.
    Any exception raised while handling a request is re-raised as
    QueryExecutionError carrying a credential-sanitized message.
    """

    def __init__(
        self,
        registry: ConnectionRegistry,
        settings: ServerSettings,
        audit_logger: AuditLogger,
    ):
        """Store the collaborators; no connection is opened here."""
        self._registry = registry
        self._settings = settings
        self._audit = audit_logger

    def run_select(
        self, connection_id: str, sql_text: str, limit: Optional[int] = None
    ) -> Dict[str, object]:
        """Validate and execute a SELECT statement, returning at most ``limit`` rows.

        The requested limit is clamped with the configured default and
        maximum before the query is built.

        Returns:
            A dict with the connection id, engine, column names, rows, row
            count, a ``truncated`` flag (more rows were fetched than the
            applied limit), the duration in milliseconds and the applied limit.

        Raises:
            QueryExecutionError: For any failure after the limit was clamped.
        """
        started = time.perf_counter()
        config = None
        settings = self._settings
        row_limit = clamp_limit(limit, settings.default_limit, settings.max_limit)

        try:
            config = self._registry.get_connection_config(connection_id)
            validated_sql = validate_select_sql(sql_text, config.engine)
            statement, _unused = build_limited_query(validated_sql, row_limit)
            sql_summary = summarize_sql(validated_sql)
            with self._registry.connection_from_config(config) as (
                connection,
                adapter,
            ):
                _apply_statement_timeout(
                    adapter, connection, settings.statement_timeout_ms
                )
                with connection.cursor() as cursor:
                    cursor.execute(statement)
                    columns = adapter.column_names(cursor.description)
                    fetched = cursor.fetchall()

            # Fetching more rows than the limit signals that the result was cut.
            truncated = len(fetched) > row_limit
            visible_rows = fetched[:row_limit]
            elapsed = _elapsed_ms(started)
            audit_details = {
                "engine": config.engine,
                "limit": row_limit,
                "truncated": truncated,
            }
            self._audit.log(
                tool="run_select",
                connection_id=connection_id,
                success=True,
                duration_ms=elapsed,
                row_count=len(visible_rows),
                sql_summary=sql_summary,
                extra=audit_details,
            )
            response = {
                "connection_id": connection_id,
                "engine": config.engine,
                "columns": columns,
                "rows": visible_rows,
                "row_count": len(visible_rows),
                "truncated": truncated,
                "duration_ms": elapsed,
                "applied_limit": row_limit,
            }
            return response
        except Exception as error:
            elapsed = _elapsed_ms(started)
            # Summarize the raw text: validation may be the step that failed.
            sql_summary = summarize_sql(sql_text)
            safe_message = sanitize_error_message(str(error))
            self._audit.log(
                tool="run_select",
                connection_id=connection_id,
                success=False,
                duration_ms=elapsed,
                sql_summary=sql_summary,
                error=safe_message,
                extra=_build_audit_extra(config, limit=row_limit),
            )
            raise QueryExecutionError(safe_message) from error

    def explain_query(
        self, connection_id: str, sql_text: str, analyze: bool = False
    ) -> Dict[str, object]:
        """Validate a SELECT statement and return its execution plan.

        Returns:
            A dict with the connection id, engine, the plan extracted by the
            adapter, the duration in milliseconds and the ``analyze`` flag.

        Raises:
            QueryExecutionError: For any failure while handling the request.
        """
        started = time.perf_counter()
        config = None
        try:
            config = self._registry.get_connection_config(connection_id)
            validated_sql = validate_select_sql(sql_text, config.engine)
            sql_summary = summarize_sql(validated_sql)
            # The EXPLAIN text is built before a connection is opened; the
            # name is rebound below to the adapter handed out with it.
            adapter = self._registry.get_adapter(config)
            statement = adapter.build_explain_query(validated_sql, analyze=analyze)
            with self._registry.connection_from_config(config) as (
                connection,
                adapter,
            ):
                _apply_statement_timeout(
                    adapter, connection, self._settings.statement_timeout_ms
                )
                with connection.cursor() as cursor:
                    cursor.execute(statement)
                    fetched = cursor.fetchall()

            elapsed = _elapsed_ms(started)
            plan = adapter.extract_plan(fetched)
            plan_count = 1 if fetched else 0
            self._audit.log(
                tool="explain_query",
                connection_id=connection_id,
                success=True,
                duration_ms=elapsed,
                row_count=plan_count,
                sql_summary=sql_summary,
                extra={"engine": config.engine, "analyze": analyze},
            )
            response = {
                "connection_id": connection_id,
                "engine": config.engine,
                "plan": plan,
                "duration_ms": elapsed,
                "analyze": analyze,
            }
            return response
        except Exception as error:
            elapsed = _elapsed_ms(started)
            sql_summary = summarize_sql(sql_text)
            safe_message = sanitize_error_message(str(error))
            self._audit.log(
                tool="explain_query",
                connection_id=connection_id,
                success=False,
                duration_ms=elapsed,
                sql_summary=sql_summary,
                error=safe_message,
                extra=_build_audit_extra(config, analyze=analyze),
            )
            raise QueryExecutionError(safe_message) from error

    def get_table_sample(
        self,
        connection_id: str,
        table_name: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Dict[str, object]:
        """Fetch up to ``limit`` sample rows from a table.

        The namespace (schema or database) is resolved from the arguments and
        the connection config, and one row more than the applied limit is
        requested so that truncation can be detected.

        Returns:
            A dict with the connection id, engine, the resolved namespace
            under its own field name, table name, columns, rows, row count,
            ``truncated`` flag, duration in milliseconds and applied limit.

        Raises:
            QueryExecutionError: For any failure after the limit was clamped.
        """
        row_limit = clamp_limit(
            limit, self._settings.default_limit, self._settings.max_limit
        )
        started = time.perf_counter()
        config = None
        try:
            config = self._registry.get_connection_config(connection_id)
            namespace = resolve_namespace(config, schema=schema, database=database)
            adapter = self._registry.get_adapter(config)
            probe_limit = row_limit + 1
            statement = adapter.build_sample_query(
                namespace.value, table_name, probe_limit
            )
            with self._registry.connection_from_config(config) as (
                connection,
                adapter,
            ):
                _apply_statement_timeout(
                    adapter, connection, self._settings.statement_timeout_ms
                )
                with connection.cursor() as cursor:
                    cursor.execute(statement)
                    columns = adapter.column_names(cursor.description)
                    fetched = cursor.fetchall()

            truncated = len(fetched) > row_limit
            visible_rows = fetched[:row_limit]
            elapsed = _elapsed_ms(started)
            audit_details = {
                "engine": config.engine,
                namespace.field_name: namespace.value,
                "table_name": table_name,
                "limit": row_limit,
            }
            self._audit.log(
                tool="get_table_sample",
                connection_id=connection_id,
                success=True,
                duration_ms=elapsed,
                row_count=len(visible_rows),
                sql_summary=f"sample {namespace.value}.{table_name}",
                extra=audit_details,
            )
            response = {
                "connection_id": connection_id,
                "engine": config.engine,
                namespace.field_name: namespace.value,
                "table_name": table_name,
                "columns": columns,
                "rows": visible_rows,
                "row_count": len(visible_rows),
                "truncated": truncated,
                "duration_ms": elapsed,
                "applied_limit": row_limit,
            }
            return response
        except Exception as error:
            elapsed = _elapsed_ms(started)
            safe_message = sanitize_error_message(str(error))
            failure_details = _build_audit_extra(
                config,
                schema=schema,
                database=database,
                table_name=table_name,
                limit=row_limit,
            )
            self._audit.log(
                tool="get_table_sample",
                connection_id=connection_id,
                success=False,
                duration_ms=elapsed,
                error=safe_message,
                extra=failure_details,
            )
            raise QueryExecutionError(safe_message) from error
226
+
227
+
228
def _elapsed_ms(started: float) -> int:
    """Return the whole milliseconds elapsed since ``started``.

    ``started`` is a ``time.perf_counter()`` reading; the fractional
    millisecond is truncated by ``int()``.
    """
    elapsed_seconds = time.perf_counter() - started
    return int(elapsed_seconds * 1000)
230
+
231
+
232
def _apply_statement_timeout(
    adapter: Any, conn: Any, timeout_ms: Optional[int]
) -> None:
    """Ask the adapter to set a statement timeout on ``conn`` when one is configured.

    Nothing happens when ``timeout_ms`` is ``None``; otherwise the adapter's
    ``set_statement_timeout(conn, timeout_ms)`` is called and its result ignored.
    """
    if timeout_ms is None:
        return
    adapter.set_statement_timeout(conn, timeout_ms)
237
+
238
+
239
def _build_audit_extra(config, **kwargs: object) -> Dict[str, object]:
    """Assemble the ``extra`` payload for a failure audit record.

    The keyword arguments are copied into a new dict. When a connection
    config is available its engine is added under ``"engine"``, replacing
    any ``engine`` keyword; ``config`` is ``None`` when the failure happened
    before the config was obtained.
    """
    if config is None:
        return dict(kwargs)
    return {**kwargs, "engine": config.engine}