sql-query-mcp 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,225 @@
1
+ """Metadata query helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from typing import Any, Dict, Optional
7
+
8
+ from .audit import AuditLogger
9
+ from .config import ServerSettings
10
+ from .errors import QueryExecutionError, sanitize_error_message
11
+ from .namespace import require_engine, resolve_namespace
12
+ from .registry import ConnectionRegistry
13
+
14
+
15
class MetadataService:
    """Serve read-only catalogue lookups for PostgreSQL and MySQL connections.

    Every public method follows the same pipeline:
    1. Resolve the connection config through the registry.
    2. Check the engine or resolve the namespace.
    3. Open a connection and apply the configured statement timeout.
    4. Make one adapter call and write a success audit record while the
       connection is still open.

    On failure the error message is sanitized, audited, and re-raised as
    ``QueryExecutionError`` chained to the original exception.
    """

    def __init__(
        self,
        registry: ConnectionRegistry,
        settings: ServerSettings,
        audit_logger: AuditLogger,
    ):
        self._registry = registry
        self._settings = settings
        self._audit = audit_logger

    def _audit_success(
        self,
        tool: str,
        connection_id: str,
        started: float,
        row_count: int,
        extra: Dict[str, object],
    ) -> None:
        """Write the success audit record for *tool*."""
        elapsed = _elapsed_ms(started)
        self._audit.log(
            tool=tool,
            connection_id=connection_id,
            success=True,
            duration_ms=elapsed,
            row_count=row_count,
            extra=extra,
        )

    def _audit_failure(
        self,
        tool: str,
        connection_id: str,
        started: float,
        exc: Exception,
        extra: Dict[str, object],
    ) -> QueryExecutionError:
        """Audit a failed call and return the sanitized error for the caller to raise."""
        elapsed = _elapsed_ms(started)
        message = sanitize_error_message(str(exc))
        self._audit.log(
            tool=tool,
            connection_id=connection_id,
            success=False,
            duration_ms=elapsed,
            error=message,
            extra=extra,
        )
        return QueryExecutionError(message)

    def list_schemas(self, connection_id: str) -> Dict[str, object]:
        """List schemas on a PostgreSQL connection.

        Returns:
            A dict with ``connection_id``, ``engine`` (always ``"postgres"``)
            and ``schemas``.

        Raises:
            QueryExecutionError: failures, including use on a non-PostgreSQL
                connection, are re-raised as this type.
        """
        tool = "list_schemas"
        started = time.perf_counter()
        config = None  # stays None when the config lookup itself fails
        try:
            config = self._registry.get_connection_config(connection_id)
            require_engine(config, "postgres", tool)
            with self._registry.connection_from_config(config) as (conn, adapter):
                _apply_statement_timeout(
                    adapter, conn, self._settings.statement_timeout_ms
                )
                schemas = adapter.list_schemas(conn)
                self._audit_success(
                    tool,
                    connection_id,
                    started,
                    len(schemas),
                    {"engine": config.engine},
                )
                return {
                    "connection_id": connection_id,
                    "engine": "postgres",
                    "schemas": schemas,
                }
        except Exception as exc:
            raise self._audit_failure(
                tool, connection_id, started, exc, _build_audit_extra(config)
            ) from exc

    def list_databases(self, connection_id: str) -> Dict[str, object]:
        """List databases on a MySQL connection.

        Returns:
            A dict with ``connection_id``, ``engine`` (always ``"mysql"``)
            and ``databases``.

        Raises:
            QueryExecutionError: failures, including use on a non-MySQL
                connection, are re-raised as this type.
        """
        tool = "list_databases"
        started = time.perf_counter()
        config = None  # stays None when the config lookup itself fails
        try:
            config = self._registry.get_connection_config(connection_id)
            require_engine(config, "mysql", tool)
            with self._registry.connection_from_config(config) as (conn, adapter):
                _apply_statement_timeout(
                    adapter, conn, self._settings.statement_timeout_ms
                )
                databases = adapter.list_databases(conn)
                self._audit_success(
                    tool,
                    connection_id,
                    started,
                    len(databases),
                    {"engine": config.engine},
                )
                return {
                    "connection_id": connection_id,
                    "engine": "mysql",
                    "databases": databases,
                }
        except Exception as exc:
            raise self._audit_failure(
                tool, connection_id, started, exc, _build_audit_extra(config)
            ) from exc

    def list_tables(
        self,
        connection_id: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
    ) -> Dict[str, object]:
        """List tables in the resolved schema (PostgreSQL) or database (MySQL).

        Returns:
            A dict with ``connection_id``, ``engine``, the namespace key
            (``schema`` or ``database``) and ``tables``.

        Raises:
            QueryExecutionError: failures, including namespace-resolution
                errors, are re-raised as this type.
        """
        tool = "list_tables"
        started = time.perf_counter()
        config = None  # stays None when the config lookup itself fails
        try:
            config = self._registry.get_connection_config(connection_id)
            namespace = resolve_namespace(config, schema=schema, database=database)
            with self._registry.connection_from_config(config) as (conn, adapter):
                _apply_statement_timeout(
                    adapter, conn, self._settings.statement_timeout_ms
                )
                tables = adapter.list_tables(conn, namespace.value)
                scope = {namespace.field_name: namespace.value}
                self._audit_success(
                    tool,
                    connection_id,
                    started,
                    len(tables),
                    {"engine": config.engine, **scope},
                )
                return {
                    "connection_id": connection_id,
                    "engine": config.engine,
                    **scope,
                    "tables": tables,
                }
        except Exception as exc:
            raise self._audit_failure(
                tool,
                connection_id,
                started,
                exc,
                _build_audit_extra(config, schema=schema, database=database),
            ) from exc

    def describe_table(
        self,
        connection_id: str,
        table_name: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
    ) -> Dict[str, object]:
        """Describe one table's columns and indexes.

        Returns:
            A dict with ``connection_id``, ``engine``, the namespace key,
            ``table_name``, ``columns`` and ``indexes``.

        Raises:
            QueryExecutionError: if the adapter returns an empty description
                (table missing or not accessible) or any other step fails.
        """
        tool = "describe_table"
        started = time.perf_counter()
        config = None  # stays None when the config lookup itself fails
        try:
            config = self._registry.get_connection_config(connection_id)
            namespace = resolve_namespace(config, schema=schema, database=database)
            with self._registry.connection_from_config(config) as (conn, adapter):
                _apply_statement_timeout(
                    adapter, conn, self._settings.statement_timeout_ms
                )
                description = adapter.describe_table(conn, namespace.value, table_name)
                if not description:
                    raise QueryExecutionError(
                        f"未找到表 {namespace.value}.{table_name},或当前用户没有访问权限"
                    )

                columns = description["columns"]
                self._audit_success(
                    tool,
                    connection_id,
                    started,
                    len(columns),
                    {
                        "engine": config.engine,
                        namespace.field_name: namespace.value,
                        "table_name": table_name,
                    },
                )
                # "indexes" is read only after the success record is written.
                return {
                    "connection_id": connection_id,
                    "engine": config.engine,
                    namespace.field_name: namespace.value,
                    "table_name": table_name,
                    "columns": columns,
                    "indexes": description["indexes"],
                }
        except Exception as exc:
            raise self._audit_failure(
                tool,
                connection_id,
                started,
                exc,
                _build_audit_extra(
                    config,
                    schema=schema,
                    database=database,
                    table_name=table_name,
                ),
            ) from exc
208
+
209
+
210
+ def _elapsed_ms(started: float) -> int:
211
+ return int((time.perf_counter() - started) * 1000)
212
+
213
+
214
+ def _apply_statement_timeout(
215
+ adapter: Any, conn: Any, timeout_ms: Optional[int]
216
+ ) -> None:
217
+ if timeout_ms is not None:
218
+ getattr(adapter, "set_statement_timeout")(conn, timeout_ms)
219
+
220
+
221
+ def _build_audit_extra(config, **kwargs: object) -> Dict[str, object]:
222
+ extra = dict(kwargs)
223
+ if config is not None:
224
+ extra["engine"] = config.engine
225
+ return extra
@@ -0,0 +1,48 @@
1
+ """Namespace resolution helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ from .config import ConnectionConfig
9
+ from .errors import SecurityError
10
+
11
+
12
@dataclass(frozen=True)
class NamespaceSelection:
    """The namespace a metadata call is scoped to.

    As built by ``resolve_namespace``, ``field_name`` is ``"schema"`` for
    PostgreSQL and ``"database"`` for MySQL, and ``value`` is the resolved name.
    """

    field_name: str
    value: str


def resolve_namespace(
    config: ConnectionConfig,
    *,
    schema: Optional[str] = None,
    database: Optional[str] = None,
) -> NamespaceSelection:
    """Pick the schema (PostgreSQL) or database (MySQL) a metadata call targets.

    An explicitly passed, non-empty name wins over the connection's configured
    default (``default_schema`` / ``default_database``).

    Raises:
        SecurityError: if both *schema* and *database* are passed, if the
            argument that does not apply to the engine is passed, if no name can
            be resolved, or if the engine is unknown.
    """
    if schema and database:
        raise SecurityError("schema 和 database 不能同时传入。")

    engine = config.engine
    if engine == "postgres":
        if database:
            raise SecurityError("PostgreSQL 连接不接受 database 参数。")
        chosen = schema or config.default_schema
        if chosen:
            return NamespaceSelection(field_name="schema", value=chosen)
        raise SecurityError("PostgreSQL 连接必须显式传 schema,或在配置中设置 default_schema。")

    if engine == "mysql":
        if schema:
            raise SecurityError("MySQL 连接不接受 schema 参数。")
        chosen = database or config.default_database
        if chosen:
            return NamespaceSelection(field_name="database", value=chosen)
        raise SecurityError("MySQL 连接必须显式传 database,或在配置中设置 default_database。")

    raise SecurityError(f"未知 engine: {engine}")
44
+
45
+
46
def require_engine(config: ConnectionConfig, engine: str, tool_name: str) -> None:
    """Ensure *config* belongs to *engine* before running an engine-specific tool.

    Raises:
        SecurityError: naming *tool_name*, the required engine and the actual
            one, when they differ.
    """
    actual = config.engine
    if actual == engine:
        return
    raise SecurityError(f"{tool_name} 仅适用于 {engine} 连接,当前连接 engine={actual}")
@@ -0,0 +1,67 @@
1
+ """Connection registry and adapter routing."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from contextlib import contextmanager
7
+ from typing import Any, Iterator, Tuple
8
+
9
+ from .adapters import MySQLAdapter, PostgresAdapter
10
+ from .config import AppConfig, ConnectionConfig
11
+ from .errors import ConfigurationError, ConnectionNotFoundError
12
+
13
+
14
class ConnectionRegistry:
    """Resolve connection configs and route each one to its engine's adapter.

    One adapter instance per supported engine is created up front and shared
    by every connection of that engine.
    """

    def __init__(self, app_config: AppConfig):
        self._config = app_config
        # Keyed by ``ConnectionConfig.engine``; engines missing here are
        # rejected by ``get_adapter``.
        self._adapters = {
            "postgres": PostgresAdapter(),
            "mysql": MySQLAdapter(),
        }

    def list_connections(self):
        """Return the ``summary`` of every configured connection, in config order."""
        return [item.summary for item in self._config.connections]

    def get_connection_config(self, connection_id: str) -> ConnectionConfig:
        """Look up an enabled connection's config by id.

        Raises:
            ConnectionNotFoundError: if the id is unknown or the connection is
                disabled.
        """
        try:
            config = self._config.connection_map[connection_id]
        except KeyError as exc:
            raise ConnectionNotFoundError(
                f"未知 connection_id: {connection_id}"
            ) from exc
        if not config.enabled:
            raise ConnectionNotFoundError(f"connection_id 已被禁用: {connection_id}")
        return config

    @contextmanager
    def connection(
        self, connection_id: str
    ) -> Iterator[Tuple[Any, ConnectionConfig, Any]]:
        """Open a connection by id, yielding ``(conn, config, adapter)``."""
        config = self.get_connection_config(connection_id)
        with self.connection_from_config(config) as (conn, adapter):
            yield conn, config, adapter

    @contextmanager
    def connection_from_config(
        self, config: ConnectionConfig
    ) -> Iterator[Tuple[Any, Any]]:
        """Open a connection for *config*, yielding ``(conn, adapter)``.

        The DSN is read at call time from the environment variable named by
        ``config.dsn_env``.

        Raises:
            ConfigurationError: if that variable is unset or empty, or the
                engine has no adapter.
        """
        dsn = os.environ.get(config.dsn_env)
        if not dsn:
            raise ConfigurationError(
                f"{config.connection_id} 缺少环境变量 {config.dsn_env},无法建立数据库连接"
            )
        adapter = self.get_adapter(config)
        with adapter.connection(config.connection_id, dsn) as conn:
            yield conn, adapter

    def close(self) -> None:
        """Close every adapter, then re-raise the first close error, if any."""
        first_error = None
        for adapter in self._adapters.values():
            # Keep going after a failure so one misbehaving adapter cannot
            # leave the others open.
            try:
                adapter.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    def get_adapter(self, config: ConnectionConfig):
        """Return the shared adapter for ``config.engine``.

        Raises:
            ConfigurationError: if the engine is not supported.
        """
        try:
            return self._adapters[config.engine]
        except KeyError as exc:
            raise ConfigurationError(f"不支持的 engine: {config.engine}") from exc
@@ -0,0 +1,93 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ import re
7
+
8
+ try:
9
+ import tomllib
10
+ except ModuleNotFoundError: # pragma: no cover - Python 3.10 fallback
11
+ import tomli as tomllib # type: ignore[import-not-found]
12
+
13
+
14
# Release tags look like ``v1.2.3``; project versions like ``1.2.3``.
# ``[0-9]`` (not ``\d``) keeps non-ASCII digits out, and ``\Z`` (not ``$``)
# stops a trailing newline from slipping past ``re.match``.
TAG_PATTERN = re.compile(r"^v[0-9]+\.[0-9]+\.[0-9]+\Z")
VERSION_PATTERN = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+\Z")


@dataclass(frozen=True)
class ReleaseContext:
    """Immutable release metadata produced by ``build_release_context``."""

    tag: str  # the release tag as given, e.g. "v1.2.3"
    version: str  # the tag without its leading "v"; checked against project.version
    release_branch: str  # "release/<tag>"
23
+
24
+
25
def resolve_effective_tag(
    event_name: str, github_ref_name: str, input_tag: str | None
) -> str:
    """Return the tag a release run should use.

    Manually dispatched runs (``workflow_dispatch``) must supply *input_tag*.
    Every other event uses *github_ref_name* as-is and ignores *input_tag*.

    Raises:
        ValueError: for ``workflow_dispatch`` without a non-empty *input_tag*.
    """
    if event_name != "workflow_dispatch":
        return github_ref_name
    if not input_tag:
        raise ValueError("workflow_dispatch requires an explicit tag")
    return input_tag
33
+
34
+
35
def parse_version_tag(tag: str) -> str:
    """Validate a ``vX.Y.Z`` release tag and return the bare ``X.Y.Z`` version.

    Raises:
        ValueError: if *tag* does not fully match ``TAG_PATTERN``.
    """
    # fullmatch rather than match: with a ``$``-anchored pattern, ``match``
    # also accepts a trailing newline ("v1.2.3\n"), which would then surface
    # later as a confusing tag/pyproject version mismatch.
    if not TAG_PATTERN.fullmatch(tag):
        raise ValueError("release tags must match vX.Y.Z")
    return tag[1:]
39
+
40
+
41
def read_project_version(pyproject_path: Path) -> str:
    """Read and validate ``[project].version`` from a ``pyproject.toml`` file.

    Raises:
        ValueError: if the version is absent or empty (including when
            ``project`` is not a table), or is not an ``X.Y.Z`` string.
    """
    data = tomllib.loads(pyproject_path.read_text(encoding="utf-8"))
    project = data.get("project", {})
    # A malformed file where ``project`` is not a table would otherwise fail
    # with AttributeError on ``.get``; treat it as "missing".
    version = project.get("version") if isinstance(project, dict) else None
    if not version:
        raise ValueError("missing project.version")
    # A non-string value (e.g. ``version = 1``) would make the regex raise
    # TypeError; report it as a format error instead. fullmatch also rejects
    # a trailing newline that a ``$``-anchored ``match`` would let through.
    if not isinstance(version, str) or not VERSION_PATTERN.fullmatch(version):
        raise ValueError("project.version must match X.Y.Z")
    return version
49
+
50
+
51
def build_release_context(tag: str, pyproject_path: Path) -> ReleaseContext:
    """Check *tag* against the project version and derive release metadata.

    The tag is validated before the pyproject file is read.

    Raises:
        ValueError: if the tag is malformed, the project version is missing or
            malformed, or the two versions differ.
    """
    version = parse_version_tag(tag)
    project_version = read_project_version(pyproject_path)
    if version != project_version:
        # Include both values so a failing release job shows what to fix.
        raise ValueError(
            "tag version does not match pyproject version "
            f"({version} != {project_version})"
        )
    return ReleaseContext(tag=tag, version=version, release_branch=f"release/{tag}")
57
+
58
+
59
def should_skip_pypi_upload(
    is_recovery_run: bool,
    pypi_version_exists: bool,
    recovery_confirmed: bool,
) -> bool:
    """Decide whether the PyPI upload step can be skipped.

    All three conditions must hold:
    - this is a recovery run,
    - the version already exists on PyPI, and
    - the recovery was confirmed.
    """
    recovering_published_version = is_recovery_run and pypi_version_exists
    return recovering_published_version and recovery_confirmed
65
+
66
+
67
def decide_backmerge_action(target: str, has_open_pr: bool, has_diff: bool) -> str:
    """Choose what to do about the back-merge PR into *target*.

    Returns:
        ``"reuse"`` when a PR is already open. Otherwise ``"create"`` when
        *target* is ``"main"`` or there is a diff, else ``"skip"``.
    """
    if has_open_pr:
        return "reuse"
    needs_pr = target == "main" or has_diff
    return "create" if needs_pr else "skip"
75
+
76
+
77
+ def _build_parser() -> argparse.ArgumentParser:
78
+ parser = argparse.ArgumentParser(description="Resolve release metadata.")
79
+ parser.add_argument("--tag", required=True)
80
+ parser.add_argument("--pyproject", required=True)
81
+ return parser
82
+
83
+
84
def main() -> None:
    """CLI entry point: validate the tag against pyproject and print metadata.

    Prints ``tag=``, ``version=`` and ``release_branch=`` lines, in that order.
    Presumably these are consumed as CI step outputs; confirm against the
    workflow.
    """
    options = _build_parser().parse_args()
    release = build_release_context(options.tag, Path(options.pyproject))
    for key, value in (
        ("tag", release.tag),
        ("version", release.version),
        ("release_branch", release.release_branch),
    ):
        print(f"{key}={value}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,128 @@
1
+ """SQL validation and normalization."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from typing import Any, Optional, Tuple
7
+
8
+ from sqlglot import exp, parse_one
9
+ from sqlglot.errors import ParseError
10
+
11
+ from .errors import SecurityError
12
+
13
# Comment markers; ``_clean_sql`` rejects any SQL text that contains one.
COMMENT_TOKENS = ("--", "/*", "*/")


def _existing_expression_types(*names: str) -> Tuple[type, ...]:
    """Return the ``sqlglot.exp`` classes named in *names*.

    Names the installed sqlglot version does not define are skipped.
    """
    candidates = (getattr(exp, name, None) for name in names)
    return tuple(cls for cls in candidates if isinstance(cls, type))


# Statement types accepted as the root of a query. If none resolve, the tuple
# is empty, ``isinstance`` never matches, and every statement is rejected.
READ_ONLY_ROOT_TYPES = _existing_expression_types(
    "Select", "Union", "Except", "Intersect"
)
# Node types rejected anywhere in the parsed tree by
# ``_ensure_read_only_statement``. ``Lock`` covers row-locking clauses such as
# ``SELECT ... FOR UPDATE`` / ``FOR SHARE``, which take locks that can block
# concurrent writers even though they change no data.
MUTATING_EXPRESSION_TYPES = _existing_expression_types(
    "Insert",
    "Update",
    "Delete",
    "Merge",
    "Create",
    "Drop",
    "Alter",
    "Command",
    "Copy",
    "Call",
    "Set",
    "Pragma",
    "Use",
    "Grant",
    "Revoke",
    "Transaction",
    "Commit",
    "Rollback",
    "TruncateTable",
    "Vacuum",
    "Into",
    "Lock",
)
DIALECT_BY_ENGINE = {"postgres": "postgres", "mysql": "mysql"}
52
+
53
+
54
def validate_select_sql(sql: str, engine: str) -> str:
    """Return a cleaned copy of *sql* if it is a single read-only query.

    Checks run in this order:
    1. Textual cleanup (``_clean_sql``).
    2. A keyword gate that rejects ``EXPLAIN`` and anything not starting with
       ``SELECT`` or ``WITH``.
    3. A parse in *engine*'s dialect.
    4. An AST check for read-only content.

    Raises:
        SecurityError: when one of those checks rejects the SQL.
    """
    normalized = _clean_sql(sql)
    head = normalized.lstrip().lower()
    if head.startswith("explain"):
        raise SecurityError(
            "explain_query 会自动包装 SQL,请直接传 SELECT 或 WITH 查询。"
        )
    if not head.startswith(("select", "with")):
        raise SecurityError("仅允许 SELECT 或 WITH ... SELECT 语句。")
    _ensure_read_only_statement(_parse_statement(normalized, engine))
    return normalized
66
+
67
+
68
def clamp_limit(limit: Optional[int], default_limit: int, max_limit: int) -> int:
    """Resolve the effective row limit.

    Uses *default_limit* when *limit* is ``None``; otherwise *limit* is
    converted with ``int()``. The result is capped at *max_limit*.

    Raises:
        SecurityError: if the resolved limit is not positive.
    """
    if limit is None:
        requested = default_limit
    else:
        requested = int(limit)
    if requested <= 0:
        raise SecurityError("limit 必须大于 0。")
    return min(requested, max_limit)
73
+
74
+
75
def build_limited_query(sql: str, row_limit: int) -> Tuple[str, int]:
    """Wrap *sql* as a subquery limited to ``row_limit + 1`` rows.

    Returns:
        ``(wrapped_sql, fetch_count)``. The extra row presumably lets the
        caller detect that results were truncated; verify against the caller.
    """
    fetch_count = row_limit + 1
    query = f"SELECT * FROM ({sql}) AS _pq_result LIMIT {fetch_count}"
    return query, fetch_count
79
+
80
+
81
def summarize_sql(sql: str, max_chars: int = 160) -> str:
    """Collapse whitespace in *sql* and truncate it to at most *max_chars* characters.

    Longer text is cut and suffixed with ``"..."``. When *max_chars* is too
    small to hold the ellipsis (< 3), the text is cut without one. A
    non-positive *max_chars* yields ``""``.
    """
    one_line = re.sub(r"\s+", " ", sql).strip()
    if len(one_line) <= max_chars:
        return one_line
    # Guard the small cases: ``one_line[: max_chars - 3]`` with a negative
    # bound would slice from the end and overshoot max_chars.
    if max_chars <= 0:
        return ""
    if max_chars < 3:
        return one_line[:max_chars]
    return one_line[: max_chars - 3] + "..."
86
+
87
+
88
def _clean_sql(sql: str) -> str:
    """Trim *sql* and enforce the textual no-comment, single-statement rules.

    The checks are purely lexical. A comment marker or semicolon anywhere in
    the text counts, even inside a string literal, which errs on the side of
    rejecting. At most one semicolon is allowed, and only as the final
    character; it is stripped from the returned text.

    Raises:
        SecurityError: for empty or blank SQL, any comment marker, or more
            than one statement.
    """
    if not sql or not sql.strip():
        raise SecurityError("SQL 不能为空。")

    text = sql.strip()
    if any(marker in text for marker in COMMENT_TOKENS):
        raise SecurityError("不允许使用 SQL 注释。")

    semicolons = text.count(";")
    if semicolons > 1 or (semicolons == 1 and not text.endswith(";")):
        raise SecurityError("只允许单条 SQL 语句。")
    if text.endswith(";"):
        text = text[:-1].rstrip()

    # A lone ";" becomes empty once the terminator is stripped.
    if not text:
        raise SecurityError("SQL 不能为空。")
    return text
108
+
109
+
110
def _parse_statement(sql: str, engine: str) -> Any:
    """Parse *sql* with sqlglot in the dialect mapped to *engine*.

    Raises:
        SecurityError: if *engine* has no entry in ``DIALECT_BY_ENGINE``, or
            sqlglot fails to tokenize or parse the SQL.
    """
    # Local import: the base error class is needed here, but the module
    # header imports only ParseError.
    from sqlglot.errors import SqlglotError

    try:
        dialect = DIALECT_BY_ENGINE[engine]
    except KeyError as exc:
        raise SecurityError(f"不支持的 SQL 方言: {engine}") from exc

    try:
        return parse_one(sql, dialect=dialect)
    except SqlglotError as exc:
        # ParseError and TokenError (e.g. an unterminated string literal) both
        # derive from SqlglotError. Catching only ParseError let TokenError
        # escape as a non-SecurityError.
        raise SecurityError(f"SQL 解析失败,已拒绝执行: {exc}") from exc
120
+
121
+
122
def _ensure_read_only_statement(statement: Any) -> None:
    """Reject a parsed statement unless it is a read-only query.

    The root must be one of ``READ_ONLY_ROOT_TYPES``. No node anywhere in the
    tree may be one of ``MUTATING_EXPRESSION_TYPES``.

    Raises:
        SecurityError: naming the first offending node type, in breadth-first
            order, when a mutating node is found.
    """
    if not isinstance(statement, READ_ONLY_ROOT_TYPES):
        raise SecurityError("仅允许 SELECT 或 WITH ... SELECT 语句。")

    # ``find`` returns AST nodes (breadth-first, root included) on every
    # sqlglot release. Older releases' ``walk`` yielded ``(node, parent, key)``
    # tuples, which made a per-item isinstance check never match, so the
    # guard silently failed open.
    offending = statement.find(*MUTATING_EXPRESSION_TYPES)
    if offending is not None:
        raise SecurityError(f"仅允许只读查询,检测到写操作: {offending.key.upper()}")