sql-query-mcp 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_query_mcp/__init__.py +5 -0
- sql_query_mcp/__main__.py +5 -0
- sql_query_mcp/adapters/__init__.py +15 -0
- sql_query_mcp/adapters/mysql.py +171 -0
- sql_query_mcp/adapters/postgres.py +180 -0
- sql_query_mcp/app.py +105 -0
- sql_query_mcp/audit.py +42 -0
- sql_query_mcp/config.py +255 -0
- sql_query_mcp/errors.py +34 -0
- sql_query_mcp/executor.py +243 -0
- sql_query_mcp/introspection.py +225 -0
- sql_query_mcp/namespace.py +48 -0
- sql_query_mcp/registry.py +67 -0
- sql_query_mcp/release_metadata.py +93 -0
- sql_query_mcp/validator.py +128 -0
- sql_query_mcp-0.1.1.dist-info/METADATA +235 -0
- sql_query_mcp-0.1.1.dist-info/RECORD +21 -0
- sql_query_mcp-0.1.1.dist-info/WHEEL +5 -0
- sql_query_mcp-0.1.1.dist-info/entry_points.txt +2 -0
- sql_query_mcp-0.1.1.dist-info/licenses/LICENSE +21 -0
- sql_query_mcp-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
"""Metadata query helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
from typing import Any, Dict, Optional
|
|
7
|
+
|
|
8
|
+
from .audit import AuditLogger
|
|
9
|
+
from .config import ServerSettings
|
|
10
|
+
from .errors import QueryExecutionError, sanitize_error_message
|
|
11
|
+
from .namespace import require_engine, resolve_namespace
|
|
12
|
+
from .registry import ConnectionRegistry
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MetadataService:
    """Read-only metadata queries across PostgreSQL and MySQL.

    Each public method looks up the connection config, opens a connection
    through the registry, applies the configured statement timeout, runs one
    adapter call and records the outcome with the audit logger.  Any
    exception raised along the way is logged with a sanitized message and
    re-raised as ``QueryExecutionError``.
    """

    def __init__(
        self,
        registry: ConnectionRegistry,
        settings: ServerSettings,
        audit_logger: AuditLogger,
    ):
        self._registry = registry
        self._settings = settings
        self._audit = audit_logger

    def list_schemas(self, connection_id: str) -> Dict[str, object]:
        """Return the schemas reported by the adapter for a postgres connection.

        Raises:
            QueryExecutionError: On any failure, including a connection whose
                engine is rejected by ``require_engine``.
        """
        started = time.perf_counter()
        config = None  # stays None when the config lookup itself fails
        try:
            config = self._registry.get_connection_config(connection_id)
            require_engine(config, "postgres", "list_schemas")
            with self._registry.connection_from_config(config) as (conn, adapter):
                _apply_statement_timeout(
                    adapter, conn, self._settings.statement_timeout_ms
                )
                schemas = adapter.list_schemas(conn)
            self._audit_success(
                "list_schemas",
                connection_id,
                started,
                len(schemas),
                {"engine": config.engine},
            )
            return {
                "connection_id": connection_id,
                "engine": "postgres",
                "schemas": schemas,
            }
        except Exception as exc:
            raise self._audit_failure(
                "list_schemas", connection_id, started, config, exc
            ) from exc

    def list_databases(self, connection_id: str) -> Dict[str, object]:
        """Return the databases reported by the adapter for a mysql connection.

        Raises:
            QueryExecutionError: On any failure, including a connection whose
                engine is rejected by ``require_engine``.
        """
        started = time.perf_counter()
        config = None
        try:
            config = self._registry.get_connection_config(connection_id)
            require_engine(config, "mysql", "list_databases")
            with self._registry.connection_from_config(config) as (conn, adapter):
                _apply_statement_timeout(
                    adapter, conn, self._settings.statement_timeout_ms
                )
                databases = adapter.list_databases(conn)
            self._audit_success(
                "list_databases",
                connection_id,
                started,
                len(databases),
                {"engine": config.engine},
            )
            return {
                "connection_id": connection_id,
                "engine": "mysql",
                "databases": databases,
            }
        except Exception as exc:
            raise self._audit_failure(
                "list_databases", connection_id, started, config, exc
            ) from exc

    def list_tables(
        self,
        connection_id: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
    ) -> Dict[str, object]:
        """Return the tables of the namespace resolved from ``schema``/``database``.

        The response and the audit record carry the namespace under the key
        chosen by ``resolve_namespace`` (``namespace.field_name``).

        Raises:
            QueryExecutionError: On any failure, including namespace
                resolution errors.
        """
        started = time.perf_counter()
        config = None
        try:
            config = self._registry.get_connection_config(connection_id)
            namespace = resolve_namespace(config, schema=schema, database=database)
            with self._registry.connection_from_config(config) as (conn, adapter):
                _apply_statement_timeout(
                    adapter, conn, self._settings.statement_timeout_ms
                )
                tables = adapter.list_tables(conn, namespace.value)
            self._audit_success(
                "list_tables",
                connection_id,
                started,
                len(tables),
                {"engine": config.engine, namespace.field_name: namespace.value},
            )
            return {
                "connection_id": connection_id,
                "engine": config.engine,
                namespace.field_name: namespace.value,
                "tables": tables,
            }
        except Exception as exc:
            raise self._audit_failure(
                "list_tables",
                connection_id,
                started,
                config,
                exc,
                schema=schema,
                database=database,
            ) from exc

    def describe_table(
        self,
        connection_id: str,
        table_name: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
    ) -> Dict[str, object]:
        """Return the columns and indexes the adapter reports for one table.

        An empty/falsy description from the adapter is treated as "table not
        found or not accessible" and reported as an error.

        Raises:
            QueryExecutionError: On any failure, including a missing table.
        """
        started = time.perf_counter()
        config = None
        try:
            config = self._registry.get_connection_config(connection_id)
            namespace = resolve_namespace(config, schema=schema, database=database)
            with self._registry.connection_from_config(config) as (conn, adapter):
                _apply_statement_timeout(
                    adapter, conn, self._settings.statement_timeout_ms
                )
                description = adapter.describe_table(conn, namespace.value, table_name)
                if not description:
                    raise QueryExecutionError(
                        f"未找到表 {namespace.value}.{table_name},或当前用户没有访问权限"
                    )

            self._audit_success(
                "describe_table",
                connection_id,
                started,
                len(description["columns"]),
                {
                    "engine": config.engine,
                    namespace.field_name: namespace.value,
                    "table_name": table_name,
                },
            )
            return {
                "connection_id": connection_id,
                "engine": config.engine,
                namespace.field_name: namespace.value,
                "table_name": table_name,
                "columns": description["columns"],
                "indexes": description["indexes"],
            }
        except Exception as exc:
            raise self._audit_failure(
                "describe_table",
                connection_id,
                started,
                config,
                exc,
                schema=schema,
                database=database,
                table_name=table_name,
            ) from exc

    def _audit_success(
        self,
        tool: str,
        connection_id: str,
        started: float,
        row_count: int,
        extra: Dict[str, object],
    ) -> None:
        """Write the success audit record shared by all public methods."""
        self._audit.log(
            tool=tool,
            connection_id=connection_id,
            success=True,
            duration_ms=_elapsed_ms(started),
            row_count=row_count,
            extra=extra,
        )

    def _audit_failure(
        self,
        tool: str,
        connection_id: str,
        started: float,
        config: Any,
        exc: Exception,
        **context: object,
    ) -> QueryExecutionError:
        """Write the failure audit record and build the error to raise.

        The exception is returned rather than raised so the caller can write
        ``raise ... from exc`` inside its own ``except`` block and keep the
        original exception chained as the cause.
        """
        duration_ms = _elapsed_ms(started)
        sanitized = sanitize_error_message(str(exc))
        self._audit.log(
            tool=tool,
            connection_id=connection_id,
            success=False,
            duration_ms=duration_ms,
            error=sanitized,
            extra=_build_audit_extra(config, **context),
        )
        return QueryExecutionError(sanitized)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _elapsed_ms(started: float) -> int:
    """Return the whole milliseconds elapsed since ``started``.

    ``started`` is compared against ``time.perf_counter()``; the fractional
    part of the millisecond count is dropped by ``int``.
    """
    elapsed_seconds = time.perf_counter() - started
    return int(elapsed_seconds * 1000)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _apply_statement_timeout(
    adapter: Any, conn: Any, timeout_ms: Optional[int]
) -> None:
    """Apply the statement timeout to ``conn`` when one is configured.

    Args:
        adapter: Object providing ``set_statement_timeout(conn, timeout_ms)``.
        conn: Connection passed through to the adapter unchanged.
        timeout_ms: Timeout value to apply; ``None`` leaves the connection
            untouched and the adapter is not called.
    """
    if timeout_ms is None:
        return
    # Plain attribute access: getattr() with a constant name behaves the same
    # but hides the call from readers and static analysis.
    adapter.set_statement_timeout(conn, timeout_ms)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _build_audit_extra(config, **kwargs: object) -> Dict[str, object]:
    """Assemble the ``extra`` mapping for an audit record.

    Returns a fresh dict holding the keyword arguments.  When ``config`` is
    not ``None`` its ``engine`` attribute is stored under ``"engine"``,
    replacing any ``engine`` keyword that was passed in.
    """
    if config is None:
        return dict(kwargs)
    return {**kwargs, "engine": config.engine}
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Namespace resolution helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from .config import ConnectionConfig
|
|
9
|
+
from .errors import SecurityError
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(frozen=True)
class NamespaceSelection:
    """Immutable pairing of a namespace field name with its chosen value."""

    # Key under which the namespace is reported ("schema" or "database" when
    # built by resolve_namespace).
    field_name: str
    # Namespace value selected for that field.
    value: str
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def resolve_namespace(
    config: ConnectionConfig,
    *,
    schema: Optional[str] = None,
    database: Optional[str] = None,
) -> NamespaceSelection:
    """Choose the namespace a call should target for ``config``'s engine.

    postgres connections use ``schema`` (falling back to
    ``config.default_schema``); mysql connections use ``database`` (falling
    back to ``config.default_database``).

    Raises:
        SecurityError: If both arguments are given, if the argument belonging
            to the other engine is given, if no namespace can be resolved,
            or if the engine is neither ``postgres`` nor ``mysql``.
    """
    if schema and database:
        raise SecurityError("schema 和 database 不能同时传入。")

    engine = config.engine

    if engine == "postgres":
        if database:
            raise SecurityError("PostgreSQL 连接不接受 database 参数。")
        chosen = schema if schema else config.default_schema
        if chosen:
            return NamespaceSelection(field_name="schema", value=chosen)
        raise SecurityError("PostgreSQL 连接必须显式传 schema,或在配置中设置 default_schema。")

    if engine == "mysql":
        if schema:
            raise SecurityError("MySQL 连接不接受 schema 参数。")
        chosen = database if database else config.default_database
        if chosen:
            return NamespaceSelection(field_name="database", value=chosen)
        raise SecurityError("MySQL 连接必须显式传 database,或在配置中设置 default_database。")

    raise SecurityError(f"未知 engine: {config.engine}")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def require_engine(config: ConnectionConfig, engine: str, tool_name: str) -> None:
    """Reject ``config`` unless its engine equals ``engine``.

    Raises:
        SecurityError: When ``config.engine`` differs from ``engine``; the
            message names ``tool_name`` and both engines.
    """
    mismatch = config.engine != engine
    if not mismatch:
        return
    raise SecurityError(f"{tool_name} 仅适用于 {engine} 连接,当前连接 engine={config.engine}")
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Connection registry and adapter routing."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
from typing import Any, Iterator, Tuple
|
|
8
|
+
|
|
9
|
+
from .adapters import MySQLAdapter, PostgresAdapter
|
|
10
|
+
from .config import AppConfig, ConnectionConfig
|
|
11
|
+
from .errors import ConfigurationError, ConnectionNotFoundError
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ConnectionRegistry:
    """Resolve connection config and route to the correct engine adapter."""

    def __init__(self, app_config: AppConfig):
        self._config = app_config
        # One adapter instance per supported engine, keyed by engine name.
        self._adapters = dict(postgres=PostgresAdapter(), mysql=MySQLAdapter())

    def list_connections(self):
        """Return the ``summary`` of every configured connection, in order."""
        summaries = []
        for item in self._config.connections:
            summaries.append(item.summary)
        return summaries

    def get_connection_config(self, connection_id: str) -> ConnectionConfig:
        """Look up an enabled connection config by id.

        Raises:
            ConnectionNotFoundError: If the id is unknown or the connection
                is disabled.
        """
        try:
            entry = self._config.connection_map[connection_id]
        except KeyError as missing:
            raise ConnectionNotFoundError(
                f"未知 connection_id: {connection_id}"
            ) from missing
        if entry.enabled:
            return entry
        raise ConnectionNotFoundError(f"connection_id 已被禁用: {connection_id}")

    @contextmanager
    def connection(
        self, connection_id: str
    ) -> Iterator[Tuple[Any, ConnectionConfig, Any]]:
        """Yield ``(conn, config, adapter)`` for the connection id."""
        config = self.get_connection_config(connection_id)
        with self.connection_from_config(config) as opened:
            conn, adapter = opened
            yield conn, config, adapter

    @contextmanager
    def connection_from_config(
        self, config: ConnectionConfig
    ) -> Iterator[Tuple[Any, Any]]:
        """Yield ``(conn, adapter)`` for an already-resolved config.

        The DSN is read from the environment variable named by
        ``config.dsn_env``.

        Raises:
            ConfigurationError: If that variable is unset or empty, or the
                engine has no adapter.
        """
        dsn = os.environ.get(config.dsn_env)
        if not dsn:
            raise ConfigurationError(
                f"{config.connection_id} 缺少环境变量 {config.dsn_env},无法建立数据库连接"
            )
        adapter = self.get_adapter(config)
        with adapter.connection(config.connection_id, dsn) as conn:
            yield conn, adapter

    def close(self) -> None:
        """Call ``close()`` on every registered adapter."""
        for engine_name in self._adapters:
            self._adapters[engine_name].close()

    def get_adapter(self, config: ConnectionConfig):
        """Return the adapter registered for ``config.engine``.

        Raises:
            ConfigurationError: If no adapter is registered for the engine.
        """
        try:
            adapter = self._adapters[config.engine]
        except KeyError as missing:
            raise ConfigurationError(f"不支持的 engine: {config.engine}") from missing
        return adapter
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import re
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
import tomllib
|
|
10
|
+
except ModuleNotFoundError: # pragma: no cover - Python 3.10 fallback
|
|
11
|
+
import tomli as tomllib # type: ignore[import-not-found]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Release tags look like "v1.2.3".  ``\Z`` anchors at the true end of the
# string; ``$`` would also match just before a trailing newline, letting
# "v1.2.3\n" through into the version, branch name and printed output.
TAG_PATTERN = re.compile(r"^v\d+\.\d+\.\d+\Z")
# Project versions are the same three numeric components without the "v".
VERSION_PATTERN = re.compile(r"^\d+\.\d+\.\d+\Z")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass(frozen=True)
class ReleaseContext:
    """Immutable release metadata produced by ``build_release_context``."""

    # Release tag exactly as supplied (e.g. "v1.2.3").
    tag: str
    # Version derived from the tag with the leading "v" removed.
    version: str
    # Branch name, built as "release/<tag>".
    release_branch: str
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def resolve_effective_tag(
    event_name: str, github_ref_name: str, input_tag: str | None
) -> str:
    """Pick the tag to release from.

    For a ``workflow_dispatch`` event the explicit ``input_tag`` is used;
    for every other event ``github_ref_name`` is returned unchanged.

    Raises:
        ValueError: If the event is ``workflow_dispatch`` and ``input_tag``
            is missing or empty.
    """
    if event_name != "workflow_dispatch":
        return github_ref_name
    if input_tag:
        return input_tag
    raise ValueError("workflow_dispatch requires an explicit tag")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def parse_version_tag(tag: str) -> str:
    """Return the version part of a release tag (the tag minus its first character).

    Raises:
        ValueError: If ``tag`` does not match ``TAG_PATTERN``.
    """
    matched = TAG_PATTERN.match(tag)
    if matched:
        # Drop the leading "v" that TAG_PATTERN requires.
        return tag[1:]
    raise ValueError("release tags must match vX.Y.Z")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def read_project_version(pyproject_path: Path) -> str:
    """Read ``project.version`` from a pyproject TOML file.

    The file is read as UTF-8 text and parsed with ``tomllib``.

    Raises:
        ValueError: If ``project.version`` is missing or empty, or does not
            match ``VERSION_PATTERN``.
    """
    raw_toml = pyproject_path.read_text(encoding="utf-8")
    project_table = tomllib.loads(raw_toml).get("project", {})
    version = project_table.get("version")
    if not version:
        raise ValueError("missing project.version")
    if VERSION_PATTERN.match(version):
        return version
    raise ValueError("project.version must match X.Y.Z")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def build_release_context(tag: str, pyproject_path: Path) -> ReleaseContext:
    """Validate ``tag`` against the pyproject version and build the context.

    The tag is parsed first, then the project version is read; the two must
    be equal.

    Raises:
        ValueError: If the tag or project version is malformed, or if they
            differ.
    """
    tag_version = parse_version_tag(tag)
    declared_version = read_project_version(pyproject_path)
    if tag_version == declared_version:
        return ReleaseContext(
            tag=tag, version=tag_version, release_branch=f"release/{tag}"
        )
    raise ValueError("tag version does not match pyproject version")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def should_skip_pypi_upload(
    is_recovery_run: bool,
    pypi_version_exists: bool,
    recovery_confirmed: bool,
) -> bool:
    """Return whether the PyPI upload should be skipped.

    True only when all three flags are true.
    """
    # Spelled out as guards; each returns the same operand the chained
    # ``and`` expression would have produced.
    if not is_recovery_run:
        return is_recovery_run
    if not pypi_version_exists:
        return pypi_version_exists
    return recovery_confirmed
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def decide_backmerge_action(target: str, has_open_pr: bool, has_diff: bool) -> str:
    """Decide what to do about a back-merge pull request.

    Returns:
        ``"reuse"`` when a PR is already open; otherwise ``"create"`` when
        the target is ``main`` or there is a diff; otherwise ``"skip"``.
    """
    if has_open_pr:
        return "reuse"
    if target == "main" or has_diff:
        return "create"
    return "skip"
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with the required ``--tag`` and ``--pyproject`` options."""
    parser = argparse.ArgumentParser(description="Resolve release metadata.")
    for option in ("--tag", "--pyproject"):
        parser.add_argument(option, required=True)
    return parser
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def main() -> None:
    """Parse the CLI arguments and print the release context as ``key=value`` lines."""
    cli_args = _build_parser().parse_args()
    context = build_release_context(cli_args.tag, Path(cli_args.pyproject))
    output_lines = (
        f"tag={context.tag}",
        f"version={context.version}",
        f"release_branch={context.release_branch}",
    )
    for line in output_lines:
        print(line)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""SQL validation and normalization."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from typing import Any, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
from sqlglot import exp, parse_one
|
|
9
|
+
from sqlglot.errors import ParseError
|
|
10
|
+
|
|
11
|
+
from .errors import SecurityError
|
|
12
|
+
|
|
13
|
+
# Substrings that mark SQL comments; _clean_sql rejects any SQL containing one.
COMMENT_TOKENS = ("--", "/*", "*/")


def _expression_types(*names: str) -> tuple:
    """Return the ``sqlglot.exp`` classes named in ``names``, in order.

    Names that ``exp`` does not define, or that are not classes, are skipped.
    """
    found = []
    for name in names:
        candidate = getattr(exp, name, None)
        if isinstance(candidate, type):
            found.append(candidate)
    return tuple(found)


# Expression types accepted as the root of a read-only statement.
READ_ONLY_ROOT_TYPES = _expression_types("Select", "Union", "Except", "Intersect")
# Expression types whose presence anywhere in the tree rejects the statement.
MUTATING_EXPRESSION_TYPES = _expression_types(
    "Insert",
    "Update",
    "Delete",
    "Merge",
    "Create",
    "Drop",
    "Alter",
    "Command",
    "Copy",
    "Call",
    "Set",
    "Pragma",
    "Use",
    "Grant",
    "Revoke",
    "Transaction",
    "Commit",
    "Rollback",
    "TruncateTable",
    "Vacuum",
    "Into",
)
# Engine name -> sqlglot dialect name used for parsing.
DIALECT_BY_ENGINE = {"postgres": "postgres", "mysql": "mysql"}
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def validate_select_sql(sql: str, engine: str) -> str:
    """Validate that ``sql`` is a single read-only query and return it cleaned.

    The SQL is cleaned by ``_clean_sql``, must start with ``select`` or
    ``with`` (case-insensitive), is parsed for ``engine``'s dialect and is
    then checked by ``_ensure_read_only_statement``.

    Raises:
        SecurityError: If any of those checks fails, or if the SQL starts
            with ``explain``.
    """
    cleaned = _clean_sql(sql)
    head = cleaned.lstrip().lower()
    if head.startswith("explain"):
        raise SecurityError("explain_query 会自动包装 SQL,请直接传 SELECT 或 WITH 查询。")
    if not head.startswith(("select", "with")):
        raise SecurityError("仅允许 SELECT 或 WITH ... SELECT 语句。")
    _ensure_read_only_statement(_parse_statement(cleaned, engine))
    return cleaned
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def clamp_limit(limit: Optional[int], default_limit: int, max_limit: int) -> int:
    """Resolve the effective row limit.

    Uses ``default_limit`` when ``limit`` is ``None``, otherwise
    ``int(limit)``; the result is capped at ``max_limit``.

    Raises:
        SecurityError: If the resolved value is zero or negative.
    """
    if limit is None:
        requested = default_limit
    else:
        requested = int(limit)
    if requested <= 0:
        raise SecurityError("limit 必须大于 0。")
    return min(requested, max_limit)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def build_limited_query(sql: str, row_limit: int) -> Tuple[str, int]:
    """Wrap ``sql`` in an outer SELECT limited to ``row_limit + 1`` rows.

    Returns:
        ``(wrapped_sql, sentinel_limit)`` where ``sentinel_limit`` is
        ``row_limit + 1``.
        NOTE(review): the extra row presumably lets the caller detect
        truncation; confirm against the executor.
    """
    sentinel_limit = row_limit + 1
    return (
        f"SELECT * FROM ({sql}) AS _pq_result LIMIT {sentinel_limit}",
        sentinel_limit,
    )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def summarize_sql(sql: str, max_chars: int = 160) -> str:
    """Collapse ``sql`` onto one line and truncate it to ``max_chars``.

    Runs of whitespace become a single space and the ends are stripped.
    Text longer than ``max_chars`` is cut and suffixed with ``"..."`` so the
    result is exactly ``max_chars`` long.

    Args:
        sql: SQL text to summarize.
        max_chars: Maximum length of the returned string.  Values below 3
            leave no room for text plus a full ellipsis; the result is then
            the ellipsis itself cut to ``max_chars`` (empty for values <= 0).

    Returns:
        The summary, never longer than ``max(max_chars, 0)`` characters.
    """
    ellipsis = "..."
    one_line = re.sub(r"\s+", " ", sql).strip()
    if len(one_line) <= max_chars:
        return one_line
    if max_chars < len(ellipsis):
        # A negative slice bound would trim from the END of the text and
        # produce a summary longer than the requested limit.
        return ellipsis[: max(max_chars, 0)]
    return one_line[: max_chars - len(ellipsis)] + ellipsis
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _clean_sql(sql: str) -> str:
    """Strip ``sql`` and enforce the single-statement, no-comment rules.

    A single trailing semicolon is removed (together with whitespace before
    it).

    Raises:
        SecurityError: If the SQL is empty or blank, contains any of
            ``COMMENT_TOKENS``, or has a semicolon anywhere except as the
            final character.
    """
    cleaned = sql.strip() if sql else ""
    if not cleaned:
        raise SecurityError("SQL 不能为空。")

    if any(token in cleaned for token in COMMENT_TOKENS):
        raise SecurityError("不允许使用 SQL 注释。")

    # Split at the first semicolon: anything after it means either a second
    # semicolon or further text, i.e. more than one statement.
    body, separator, remainder = cleaned.partition(";")
    if remainder:
        raise SecurityError("只允许单条 SQL 语句。")
    if separator:
        cleaned = body.rstrip()

    if not cleaned:
        raise SecurityError("SQL 不能为空。")
    return cleaned
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _parse_statement(sql: str, engine: str) -> Any:
    """Parse ``sql`` with the sqlglot dialect mapped to ``engine``.

    Args:
        sql: SQL text to parse.
        engine: Key into ``DIALECT_BY_ENGINE``.

    Returns:
        The expression returned by ``sqlglot.parse_one``.

    Raises:
        SecurityError: If ``engine`` has no dialect mapping, or if sqlglot
            fails to tokenize or parse the SQL.
    """
    # Local import keeps this fix self-contained; the module already imports
    # from sqlglot.errors.
    from sqlglot.errors import TokenError

    try:
        dialect = DIALECT_BY_ENGINE[engine]
    except KeyError as exc:
        raise SecurityError(f"不支持的 SQL 方言: {engine}") from exc

    try:
        return parse_one(sql, dialect=dialect)
    except (ParseError, TokenError) as exc:
        # Tokenizer failures (e.g. an unterminated string literal) are raised
        # as TokenError, not ParseError; reject them the same way so every
        # unparseable statement surfaces as a SecurityError.
        raise SecurityError(f"SQL 解析失败,已拒绝执行: {exc}") from exc
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _ensure_read_only_statement(statement: Any) -> None:
    """Reject ``statement`` unless it is a read-only query tree.

    The root must be one of ``READ_ONLY_ROOT_TYPES`` and no node in the tree
    may be one of ``MUTATING_EXPRESSION_TYPES``.

    Raises:
        SecurityError: If either condition is violated; for a mutating node
            the message includes that node's upper-cased ``key``.
    """
    if not isinstance(statement, READ_ONLY_ROOT_TYPES):
        raise SecurityError("仅允许 SELECT 或 WITH ... SELECT 语句。")

    # find() returns the first matching node (or None) using the same
    # traversal as walk().
    # NOTE(review): some sqlglot releases yield (node, parent, key) tuples
    # from walk(), which would make a direct isinstance() check on the
    # yielded item never match; find() is independent of that detail.
    # Confirm against the pinned sqlglot version.
    node = statement.find(*MUTATING_EXPRESSION_TYPES)
    if node is not None:
        raise SecurityError(f"仅允许只读查询,检测到写操作: {node.key.upper()}")
|