sql-query-mcp 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_query_mcp/__init__.py +5 -0
- sql_query_mcp/__main__.py +5 -0
- sql_query_mcp/adapters/__init__.py +15 -0
- sql_query_mcp/adapters/mysql.py +171 -0
- sql_query_mcp/adapters/postgres.py +180 -0
- sql_query_mcp/app.py +105 -0
- sql_query_mcp/audit.py +42 -0
- sql_query_mcp/config.py +255 -0
- sql_query_mcp/errors.py +34 -0
- sql_query_mcp/executor.py +243 -0
- sql_query_mcp/introspection.py +225 -0
- sql_query_mcp/namespace.py +48 -0
- sql_query_mcp/registry.py +67 -0
- sql_query_mcp/release_metadata.py +93 -0
- sql_query_mcp/validator.py +128 -0
- sql_query_mcp-0.1.1.dist-info/METADATA +235 -0
- sql_query_mcp-0.1.1.dist-info/RECORD +21 -0
- sql_query_mcp-0.1.1.dist-info/WHEEL +5 -0
- sql_query_mcp-0.1.1.dist-info/entry_points.txt +2 -0
- sql_query_mcp-0.1.1.dist-info/licenses/LICENSE +21 -0
- sql_query_mcp-0.1.1.dist-info/top_level.txt +1 -0
sql_query_mcp/config.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
"""Configuration loading for sql-query-mcp."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import re
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Dict, Iterable, List, Optional
|
|
11
|
+
|
|
12
|
+
from .errors import ConfigurationError
|
|
13
|
+
|
|
14
|
+
# Directory two levels above this file, i.e. the parent of the package
# directory; the default config and audit-log locations are derived from it.
PACKAGE_ROOT = Path(__file__).resolve().parent.parent
# Environment variable that may hold the path of the configuration file.
DEFAULT_CONFIG_ENV = "SQL_QUERY_MCP_CONFIG"
# Configuration file used when no explicit path and no env override is given.
DEFAULT_CONFIG_PATH = PACKAGE_ROOT.joinpath("config", "connections.json")
# Default value of ``ServerSettings.audit_log_path``.
DEFAULT_AUDIT_LOG_PATH = PACKAGE_ROOT.joinpath("logs", "audit.jsonl")
# A connection id has at least four underscore-separated segments, each made
# only of lowercase ASCII letters and digits.
CONNECTION_ID_RE = re.compile(r"^[a-z0-9]+(?:_[a-z0-9]+){3,}$")
# Values accepted for a connection's ``engine`` field.
SUPPORTED_ENGINES = {"mysql", "postgres"}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(frozen=True)
class ServerSettings:
    """Immutable server-wide limits and paths read from the ``settings`` section."""

    # Row limit applied when a caller does not pass one.
    default_limit: int = 200
    # Upper bound for any row limit; config parsing rejects values below
    # ``default_limit``.
    max_limit: int = 1000
    # Per-statement timeout in milliseconds; ``None`` means no timeout is set.
    statement_timeout_ms: Optional[int] = None
    # File that receives the audit log.
    audit_log_path: Path = DEFAULT_AUDIT_LOG_PATH
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
class ConnectionConfig:
    """Immutable description of one configured database connection."""

    connection_id: str
    engine: str
    env: str
    tenant: str
    role: str
    # Name of the environment variable that holds the DSN (not the DSN itself).
    dsn_env: str
    enabled: bool = True
    label: Optional[str] = None
    description: Optional[str] = None
    default_schema: Optional[str] = None
    default_database: Optional[str] = None

    @property
    def summary(self) -> Dict[str, object]:
        """Return a plain-dict view of this connection without ``dsn_env``.

        ``label`` falls back to ``connection_id`` when no label is set.
        """
        display_label = self.label if self.label else self.connection_id
        return dict(
            connection_id=self.connection_id,
            engine=self.engine,
            label=display_label,
            env=self.env,
            tenant=self.tenant,
            role=self.role,
            enabled=self.enabled,
            default_schema=self.default_schema,
            default_database=self.default_database,
            description=self.description,
        )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass(frozen=True)
class AppConfig:
    """Parsed configuration: server settings plus the configured connections."""

    settings: ServerSettings
    connections: List[ConnectionConfig]

    @property
    def connection_map(self) -> Dict[str, ConnectionConfig]:
        """Return a new dict that maps each ``connection_id`` to its config."""
        by_id: Dict[str, ConnectionConfig] = {}
        for connection in self.connections:
            by_id[connection.connection_id] = connection
        return by_id

    def enabled_connections(self) -> Iterable[ConnectionConfig]:
        """Lazily yield the connections whose ``enabled`` flag is truthy."""
        return (
            connection for connection in self.connections if connection.enabled
        )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def resolve_config_path(config_path: Optional[str] = None) -> Path:
    """Pick the configuration file location.

    Precedence: the explicit ``config_path`` argument, then the environment
    variable named by ``DEFAULT_CONFIG_ENV``, then ``DEFAULT_CONFIG_PATH``.
    A caller- or environment-supplied path is user-expanded and resolved.
    """
    candidate = config_path
    if not candidate:
        # Empty strings count as "not given", same as None.
        candidate = os.environ.get(DEFAULT_CONFIG_ENV)
    if not candidate:
        return DEFAULT_CONFIG_PATH
    return Path(candidate).expanduser().resolve()
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def load_config(config_path: Optional[str] = None) -> AppConfig:
    """Load and validate the JSON configuration file.

    Args:
        config_path: Optional explicit path; when omitted the location is
            chosen by ``resolve_config_path``.

    Returns:
        The parsed ``AppConfig``. A missing file is not an error: it yields
        default ``ServerSettings`` and an empty connection list.

    Raises:
        ConfigurationError: If the file is not valid JSON, if its top level
            is not a JSON object, or if a section fails validation.
    """
    path = resolve_config_path(config_path)
    if not path.exists():
        return AppConfig(settings=ServerSettings(), connections=[])

    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        raise ConfigurationError(f"配置文件不是有效 JSON: {path}") from exc

    # Valid JSON may still be a list, string or number; without this guard the
    # ``.get`` calls below would fail with a raw AttributeError.
    if not isinstance(payload, dict):
        raise ConfigurationError(f"配置文件顶层必须是 JSON 对象: {path}")

    settings = _parse_settings(payload.get("settings", {}), path)
    connections = _parse_connections(payload.get("connections", []))
    return AppConfig(settings=settings, connections=connections)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _parse_settings(data: Dict[str, object], path: Path) -> ServerSettings:
    """Validate the ``settings`` section and build a ``ServerSettings``.

    Args:
        data: The decoded ``settings`` object from the configuration file.
        path: Location of the configuration file; a relative
            ``audit_log_path`` is resolved against this file's directory.

    Raises:
        ConfigurationError: If ``data`` is not an object, a numeric field is
            not an integer, ``default_limit`` is not positive, or
            ``max_limit`` is smaller than ``default_limit``.
    """
    if not isinstance(data, dict):
        raise ConfigurationError("settings 必须是对象")

    default_limit = _required_int(data.get("default_limit", 200), "default_limit")
    max_limit = _required_int(data.get("max_limit", 1000), "max_limit")
    statement_timeout_ms = _optional_positive_int(
        data.get("statement_timeout_ms"), "statement_timeout_ms"
    )

    # An explicit JSON null means "use the default"; stringifying it would
    # produce a log file literally named "None".
    audit_log_raw = data.get("audit_log_path")
    if audit_log_raw is None:
        audit_log_raw = DEFAULT_AUDIT_LOG_PATH.relative_to(PACKAGE_ROOT)
    audit_log_path = Path(str(audit_log_raw))
    if not audit_log_path.is_absolute():
        audit_log_path = (path.parent / audit_log_path).resolve()

    if default_limit <= 0:
        raise ConfigurationError("default_limit 必须大于 0")
    if max_limit < default_limit:
        raise ConfigurationError("max_limit 不能小于 default_limit")
    return ServerSettings(
        default_limit=default_limit,
        max_limit=max_limit,
        statement_timeout_ms=statement_timeout_ms,
        audit_log_path=audit_log_path,
    )
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _parse_connections(items: object) -> List[ConnectionConfig]:
    """Validate the ``connections`` array and build one config per entry.

    Args:
        items: The decoded ``connections`` value from the configuration file.

    Returns:
        One ``ConnectionConfig`` per entry, in file order.

    Raises:
        ConfigurationError: If ``items`` is not a list, an entry is not an
            object, a required field is missing or blank, the id or engine
            is invalid, an id is duplicated, a legacy namespace field is
            present, or a namespace default does not match the engine.
    """
    if not isinstance(items, list):
        raise ConfigurationError("connections 必须是数组")

    result: List[ConnectionConfig] = []
    seen = set()
    for item in items:
        if not isinstance(item, dict):
            raise ConfigurationError("connections 数组中的每一项都必须是对象")

        # ``_optional_string`` maps JSON null to None, so ``or ""`` makes a
        # null field behave like a missing one instead of the text "None".
        connection_id = _optional_string(item.get("connection_id")) or ""
        engine = _optional_string(item.get("engine")) or ""
        label = _required_string(item, "label", connection_id or "connection")
        enabled = _required_bool(item, "enabled", connection_id or "connection")
        if not CONNECTION_ID_RE.match(connection_id):
            raise ConfigurationError(
                "connection_id 必须符合 <system>_<env>_<tenant>_<role> 风格,且只包含小写字母、数字、下划线"
            )
        if engine not in SUPPORTED_ENGINES:
            raise ConfigurationError(
                f"{connection_id} 缺少合法 engine,必须是 postgres 或 mysql"
            )
        if connection_id in seen:
            raise ConfigurationError(f"重复的 connection_id: {connection_id}")
        seen.add(connection_id)

        dsn_env = _optional_string(item.get("dsn_env")) or ""
        env = _optional_string(item.get("env")) or ""
        tenant = _optional_string(item.get("tenant")) or ""
        role = _optional_string(item.get("role")) or ""
        if not all((dsn_env, env, tenant, role)):
            raise ConfigurationError(
                f"{connection_id} 缺少必要字段,必须提供 env / tenant / role / dsn_env"
            )

        _reject_legacy_namespace_fields(item, connection_id)
        default_schema = _optional_string(item.get("default_schema"))
        default_database = _optional_string(item.get("default_database"))

        # Key presence (even with a null value) is rejected, not just a value.
        if engine == "postgres" and "default_database" in item:
            raise ConfigurationError(
                f"{connection_id} 是 PostgreSQL 连接,不能配置 default_database"
            )
        if engine == "mysql" and "default_schema" in item:
            raise ConfigurationError(
                f"{connection_id} 是 MySQL 连接,不能配置 default_schema"
            )

        # Falsy descriptions stay None; a whitespace-only one is normalized
        # to None as well, matching the other optional string fields.
        raw_description = item.get("description")
        description = _optional_string(raw_description) if raw_description else None

        result.append(
            ConnectionConfig(
                connection_id=connection_id,
                engine=engine,
                label=label,
                description=description,
                env=env,
                tenant=tenant,
                role=role,
                dsn_env=dsn_env,
                enabled=enabled,
                default_schema=default_schema,
                default_database=default_database,
            )
        )

    return result
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _reject_legacy_namespace_fields(
|
|
190
|
+
item: Dict[str, object], connection_id: str
|
|
191
|
+
) -> None:
|
|
192
|
+
if "default_namespace" in item:
|
|
193
|
+
raise ConfigurationError(
|
|
194
|
+
f"{connection_id} 仍在使用 default_namespace,请改为 default_schema 或 default_database"
|
|
195
|
+
)
|
|
196
|
+
if "default_schemas" in item:
|
|
197
|
+
raise ConfigurationError(
|
|
198
|
+
f"{connection_id} 仍在使用 default_schemas,请收敛为单值字段 default_schema"
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _optional_string(value: object) -> Optional[str]:
|
|
203
|
+
if value is None:
|
|
204
|
+
return None
|
|
205
|
+
text = str(value).strip()
|
|
206
|
+
return text or None
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _optional_positive_int(value: Any, field_name: str) -> Optional[int]:
|
|
210
|
+
if value is None:
|
|
211
|
+
return None
|
|
212
|
+
parsed = _required_int(value, field_name)
|
|
213
|
+
if parsed <= 0:
|
|
214
|
+
raise ConfigurationError(f"{field_name} 必须是大于 0 的整数")
|
|
215
|
+
return parsed
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _required_int(value: Any, field_name: str) -> int:
|
|
219
|
+
if isinstance(value, bool):
|
|
220
|
+
raise ConfigurationError(f"{field_name} 必须是整数")
|
|
221
|
+
if isinstance(value, int):
|
|
222
|
+
return value
|
|
223
|
+
if isinstance(value, float):
|
|
224
|
+
if value.is_integer():
|
|
225
|
+
return int(value)
|
|
226
|
+
raise ConfigurationError(f"{field_name} 必须是整数")
|
|
227
|
+
if isinstance(value, str):
|
|
228
|
+
try:
|
|
229
|
+
return int(value)
|
|
230
|
+
except ValueError as exc:
|
|
231
|
+
raise ConfigurationError(f"{field_name} 必须是整数") from exc
|
|
232
|
+
try:
|
|
233
|
+
return value.__int__()
|
|
234
|
+
except (TypeError, ValueError) as exc:
|
|
235
|
+
raise ConfigurationError(f"{field_name} 必须是整数") from exc
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def _required_string(
    item: Dict[str, object], field_name: str, connection_id: str
) -> str:
    """Return the stripped text of ``item[field_name]``.

    Raises ``ConfigurationError`` when the field is absent, ``None``, or blank
    after stripping; ``connection_id`` only prefixes the error message.
    """
    text = _optional_string(item.get(field_name))
    if text is not None:
        return text
    raise ConfigurationError(f"{connection_id} 缺少必要字段 {field_name}")
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _required_bool(
|
|
248
|
+
item: Dict[str, object], field_name: str, connection_id: str
|
|
249
|
+
) -> bool:
|
|
250
|
+
if field_name not in item:
|
|
251
|
+
raise ConfigurationError(f"{connection_id} 缺少必要字段 {field_name}")
|
|
252
|
+
value = item[field_name]
|
|
253
|
+
if not isinstance(value, bool):
|
|
254
|
+
raise ConfigurationError(f"{connection_id} 的 {field_name} 必须是布尔值")
|
|
255
|
+
return value
|
sql_query_mcp/errors.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Error types for sql-query-mcp."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SqlQueryMCPError(Exception):
    """Common base class for the exceptions defined by this package."""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ConfigurationError(SqlQueryMCPError):
    """Signals that the local configuration is invalid."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ConnectionNotFoundError(SqlQueryMCPError):
    """Signals that the requested ``connection_id`` does not exist."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SecurityError(SqlQueryMCPError):
    """Signals that SQL validation rejected a query."""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class QueryExecutionError(SqlQueryMCPError):
    """Signals a failure in the database execution layer."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Matches the ``user:password@`` part of a URL-style DSN. Group 1 is the user
# name, group 2 the password. The password group may contain ``@`` and is
# greedy, so it extends to the last ``@`` before the next ``/`` or whitespace;
# a password with an unescaped ``@`` is therefore masked in full rather than
# only up to its first ``@``.
_DSN_CREDENTIALS_RE = re.compile(r"://([^:@/\s]+):([^/\s]+)@")


def sanitize_error_message(message: str) -> str:
    """Mask DSN credentials before surfacing an error to the model.

    Every ``scheme://user:password@`` occurrence in ``message`` has its
    password replaced by ``***``; the user name and the rest of the text are
    kept. Text without such a pattern is returned unchanged.

    NOTE(review): when a DSN has no ``/`` after the host and a later ``@``
    appears in the same whitespace-free run, the mask also covers the host.
    That errs on the side of hiding too much rather than leaking.
    """
    return _DSN_CREDENTIALS_RE.sub(r"://\1:***@", message)
|
sql_query_mcp/executor.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
"""Query execution helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
from typing import Any, Dict, Optional
|
|
7
|
+
|
|
8
|
+
from .audit import AuditLogger
|
|
9
|
+
from .config import ServerSettings
|
|
10
|
+
from .errors import QueryExecutionError, sanitize_error_message
|
|
11
|
+
from .namespace import resolve_namespace
|
|
12
|
+
from .registry import ConnectionRegistry
|
|
13
|
+
from .validator import (
|
|
14
|
+
build_limited_query,
|
|
15
|
+
clamp_limit,
|
|
16
|
+
summarize_sql,
|
|
17
|
+
validate_select_sql,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class QueryExecutor:
    """Run validated read-only statements and audit every call.

    Each public method resolves the connection, executes one statement on a
    registry-provided connection, writes an audit record, and converts any
    failure into ``QueryExecutionError`` with a credential-sanitized message.
    """

    def __init__(
        self,
        registry: ConnectionRegistry,
        settings: ServerSettings,
        audit_logger: AuditLogger,
    ):
        """Store the collaborators; no connection is opened here."""
        self._audit = audit_logger
        self._settings = settings
        self._registry = registry

    def run_select(
        self, connection_id: str, sql_text: str, limit: Optional[int] = None
    ) -> Dict[str, object]:
        """Validate and run a SELECT, returning at most the clamped row limit.

        Returns a dict with ``connection_id``, ``engine``, ``columns``,
        ``rows``, ``row_count``, ``truncated``, ``duration_ms`` and
        ``applied_limit``.

        Raises:
            QueryExecutionError: For any failure inside the guarded section
                (lookup, validation, execution); the cause is chained.
        """
        started = time.perf_counter()
        conn_config = None
        settings = self._settings
        row_limit = clamp_limit(limit, settings.default_limit, settings.max_limit)

        try:
            conn_config = self._registry.get_connection_config(connection_id)
            cleaned_sql = validate_select_sql(sql_text, conn_config.engine)
            # NOTE(review): the second return value is unused here; the
            # truncation check below presumes the limited query can return
            # more than ``row_limit`` rows. Confirm in build_limited_query.
            limited_sql, _ = build_limited_query(cleaned_sql, row_limit)
            summary = summarize_sql(cleaned_sql)
            with self._registry.connection_from_config(conn_config) as (
                connection,
                adapter,
            ):
                _apply_statement_timeout(
                    adapter, connection, self._settings.statement_timeout_ms
                )
                with connection.cursor() as cursor:
                    cursor.execute(limited_sql)
                    columns = adapter.column_names(cursor.description)
                    fetched = cursor.fetchall()

            has_more = len(fetched) > row_limit
            page = fetched[:row_limit]
            elapsed = _elapsed_ms(started)
            audit_extra = {
                "engine": conn_config.engine,
                "limit": row_limit,
                "truncated": has_more,
            }
            self._audit.log(
                tool="run_select",
                connection_id=connection_id,
                success=True,
                duration_ms=elapsed,
                row_count=len(page),
                sql_summary=summary,
                extra=audit_extra,
            )
            return {
                "connection_id": connection_id,
                "engine": conn_config.engine,
                "columns": columns,
                "rows": page,
                "row_count": len(page),
                "truncated": has_more,
                "duration_ms": elapsed,
                "applied_limit": row_limit,
            }
        except Exception as exc:
            elapsed = _elapsed_ms(started)
            # The raw text is summarized because validation may not have run.
            summary = summarize_sql(sql_text)
            message = sanitize_error_message(str(exc))
            self._audit.log(
                tool="run_select",
                connection_id=connection_id,
                success=False,
                duration_ms=elapsed,
                sql_summary=summary,
                error=message,
                extra=_build_audit_extra(conn_config, limit=row_limit),
            )
            raise QueryExecutionError(message) from exc

    def explain_query(
        self, connection_id: str, sql_text: str, analyze: bool = False
    ) -> Dict[str, object]:
        """Run the engine's EXPLAIN for a validated SELECT and return the plan.

        Returns a dict with ``connection_id``, ``engine``, ``plan``,
        ``duration_ms`` and ``analyze``.

        Raises:
            QueryExecutionError: For any failure; the cause is chained.
        """
        started = time.perf_counter()
        conn_config = None
        try:
            conn_config = self._registry.get_connection_config(connection_id)
            cleaned_sql = validate_select_sql(sql_text, conn_config.engine)
            summary = summarize_sql(cleaned_sql)
            # The statement is built before the connection is opened.
            explain_adapter = self._registry.get_adapter(conn_config)
            explain_sql = explain_adapter.build_explain_query(
                cleaned_sql, analyze=analyze
            )
            with self._registry.connection_from_config(conn_config) as (
                connection,
                adapter,
            ):
                _apply_statement_timeout(
                    adapter, connection, self._settings.statement_timeout_ms
                )
                with connection.cursor() as cursor:
                    cursor.execute(explain_sql)
                    fetched = cursor.fetchall()

            elapsed = _elapsed_ms(started)
            plan = adapter.extract_plan(fetched)
            self._audit.log(
                tool="explain_query",
                connection_id=connection_id,
                success=True,
                duration_ms=elapsed,
                row_count=1 if fetched else 0,
                sql_summary=summary,
                extra={"engine": conn_config.engine, "analyze": analyze},
            )
            return {
                "connection_id": connection_id,
                "engine": conn_config.engine,
                "plan": plan,
                "duration_ms": elapsed,
                "analyze": analyze,
            }
        except Exception as exc:
            elapsed = _elapsed_ms(started)
            summary = summarize_sql(sql_text)
            message = sanitize_error_message(str(exc))
            self._audit.log(
                tool="explain_query",
                connection_id=connection_id,
                success=False,
                duration_ms=elapsed,
                sql_summary=summary,
                error=message,
                extra=_build_audit_extra(conn_config, analyze=analyze),
            )
            raise QueryExecutionError(message) from exc

    def get_table_sample(
        self,
        connection_id: str,
        table_name: str,
        schema: Optional[str] = None,
        database: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Dict[str, object]:
        """Return up to the clamped row limit of sample rows from one table.

        The result dict carries ``connection_id``, ``engine``, the resolved
        namespace under its own field name, ``table_name``, ``columns``,
        ``rows``, ``row_count``, ``truncated``, ``duration_ms`` and
        ``applied_limit``.

        Raises:
            QueryExecutionError: For any failure; the cause is chained.
        """
        settings = self._settings
        row_limit = clamp_limit(limit, settings.default_limit, settings.max_limit)
        started = time.perf_counter()
        conn_config = None
        try:
            conn_config = self._registry.get_connection_config(connection_id)
            namespace = resolve_namespace(
                conn_config, schema=schema, database=database
            )
            sample_adapter = self._registry.get_adapter(conn_config)
            # One extra row is requested so truncation can be detected.
            sample_sql = sample_adapter.build_sample_query(
                namespace.value, table_name, row_limit + 1
            )
            with self._registry.connection_from_config(conn_config) as (
                connection,
                adapter,
            ):
                _apply_statement_timeout(
                    adapter, connection, self._settings.statement_timeout_ms
                )
                with connection.cursor() as cursor:
                    cursor.execute(sample_sql)
                    columns = adapter.column_names(cursor.description)
                    fetched = cursor.fetchall()

            has_more = len(fetched) > row_limit
            page = fetched[:row_limit]
            elapsed = _elapsed_ms(started)
            audit_extra = {
                "engine": conn_config.engine,
                namespace.field_name: namespace.value,
                "table_name": table_name,
                "limit": row_limit,
            }
            self._audit.log(
                tool="get_table_sample",
                connection_id=connection_id,
                success=True,
                duration_ms=elapsed,
                row_count=len(page),
                sql_summary=f"sample {namespace.value}.{table_name}",
                extra=audit_extra,
            )
            return {
                "connection_id": connection_id,
                "engine": conn_config.engine,
                namespace.field_name: namespace.value,
                "table_name": table_name,
                "columns": columns,
                "rows": page,
                "row_count": len(page),
                "truncated": has_more,
                "duration_ms": elapsed,
                "applied_limit": row_limit,
            }
        except Exception as exc:
            elapsed = _elapsed_ms(started)
            message = sanitize_error_message(str(exc))
            self._audit.log(
                tool="get_table_sample",
                connection_id=connection_id,
                success=False,
                duration_ms=elapsed,
                error=message,
                extra=_build_audit_extra(
                    conn_config,
                    schema=schema,
                    database=database,
                    table_name=table_name,
                    limit=row_limit,
                ),
            )
            raise QueryExecutionError(message) from exc
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def _elapsed_ms(started: float) -> int:
|
|
229
|
+
return int((time.perf_counter() - started) * 1000)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _apply_statement_timeout(
|
|
233
|
+
adapter: Any, conn: Any, timeout_ms: Optional[int]
|
|
234
|
+
) -> None:
|
|
235
|
+
if timeout_ms is not None:
|
|
236
|
+
getattr(adapter, "set_statement_timeout")(conn, timeout_ms)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _build_audit_extra(config, **kwargs: object) -> Dict[str, object]:
|
|
240
|
+
extra = dict(kwargs)
|
|
241
|
+
if config is not None:
|
|
242
|
+
extra["engine"] = config.engine
|
|
243
|
+
return extra
|