sql-mcp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_mcp/__init__.py +64 -0
- sql_mcp/__main__.py +4 -0
- sql_mcp/agent_server.py +71 -0
- sql_mcp/api/__init__.py +9 -0
- sql_mcp/api/api_client_sql.py +469 -0
- sql_mcp/api_client.py +11 -0
- sql_mcp/auth.py +142 -0
- sql_mcp/dialects.py +183 -0
- sql_mcp/main_agent.json +14 -0
- sql_mcp/mcp/__init__.py +5 -0
- sql_mcp/mcp/mcp_sql.py +224 -0
- sql_mcp/mcp_config.json +33 -0
- sql_mcp/mcp_server.py +59 -0
- sql_mcp/safety.py +173 -0
- sql_mcp/sql_input_models.py +63 -0
- sql_mcp/sql_response_models.py +51 -0
- sql_mcp-0.1.0.dist-info/METADATA +242 -0
- sql_mcp-0.1.0.dist-info/RECORD +22 -0
- sql_mcp-0.1.0.dist-info/WHEEL +5 -0
- sql_mcp-0.1.0.dist-info/entry_points.txt +3 -0
- sql_mcp-0.1.0.dist-info/licenses/LICENSE +21 -0
- sql_mcp-0.1.0.dist-info/top_level.txt +1 -0
sql_mcp/__init__.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""sql-mcp: Generic SQL database API + MCP Server + A2A Server (SQLAlchemy 2.x Core)."""
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import inspect
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
__version__ = "0.1.0"

# Public names re-exported at package level; _expose_members appends to it.
__all__: list[str] = []

# Modules imported eagerly at package import; their public classes and
# functions are re-exported from the package.
CORE_MODULES = ["sql_mcp.api_client"]

# Optional server modules, imported by __getattr__ on the first lookup of a
# name the package does not already define.
# NOTE(review): the values ("agent", "mcp") are not read in this module —
# presumably they name the matching pip extras; confirm.
OPTIONAL_MODULES = {"sql_mcp.agent_server": "agent", "sql_mcp.mcp_server": "mcp"}
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _expose_members(module):
    """Re-export ``module``'s public classes and functions from this package.

    Each matching member is bound as a package-level global, and its name is
    appended to ``__all__`` once, so exposing the same module twice is
    harmless. Names starting with ``_`` are skipped, as is anything that is
    neither a class nor a plain Python function (constants, submodules,
    builtins).
    """
    package_namespace = globals()
    for member_name, member in inspect.getmembers(module):
        if member_name.startswith("_"):
            continue
        if not (inspect.isclass(member) or inspect.isfunction(member)):
            continue
        package_namespace[member_name] = member
        if member_name not in __all__:
            __all__.append(member_name)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _load_core_modules() -> None:
    """Import every module in ``CORE_MODULES`` and re-export its public members.

    The loop lives in a function so its loop variables stay local. A bare
    module-level loop would leave ``module`` / ``module_name`` behind as
    accidental package attributes.
    """
    for core_name in CORE_MODULES:
        _expose_members(importlib.import_module(core_name))


_load_core_modules()

# Optional modules that __getattr__ has imported successfully, keyed by their
# dotted module path. Failed imports are not recorded, so they are retried.
_loaded_optional_modules: dict[str, Any] = {}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _import_module_safely(module_name: str):
    """Import ``module_name`` and return it, or ``None`` if the import fails.

    Only ``ImportError`` (including ``ModuleNotFoundError``) is absorbed, which
    is how a missing optional dependency shows up. Any other error raised while
    importing propagates.
    """
    imported = None
    try:
        imported = importlib.import_module(module_name)
    except ImportError:
        pass
    return imported
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def __getattr__(name: str) -> Any:
    """Resolve package attributes lazily from the optional modules (PEP 562).

    ``_MCP_AVAILABLE`` / ``_AGENT_AVAILABLE`` report whether the matching
    optional server module imports cleanly. Any other non-dunder name causes
    each not-yet-loaded optional module to be imported once. Its public
    members are then re-exported via ``_expose_members``, and the name is
    looked up on the loaded modules in ``OPTIONAL_MODULES`` order.

    Raises:
        AttributeError: for dunder names, and for names no optional module
            provides.
    """
    # Introspection probes dunder names (doctest's ``__test__``,
    # ``inspect.unwrap``'s ``__wrapped__``, ``__all__`` during star-import).
    # Answering them must not import the optional servers, whose import has
    # side effects (sql_mcp.agent_server calls logging.basicConfig).
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    availability_markers = {
        "_MCP_AVAILABLE": "mcp_server",
        "_AGENT_AVAILABLE": "agent_server",
    }
    if name in availability_markers:
        marker = availability_markers[name]
        key = next((k for k in OPTIONAL_MODULES if marker in k), None)
        return key is not None and _import_module_safely(key) is not None

    for module_name in OPTIONAL_MODULES:
        module = _loaded_optional_modules.get(module_name)
        if module is None:
            module = _import_module_safely(module_name)
            if module is None:
                # Optional dependency missing; try the next module.
                continue
            _loaded_optional_modules[module_name] = module
            _expose_members(module)
        if hasattr(module, name):
            return getattr(module, name)

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def __dir__() -> list[str]:
    """List the package's attributes for ``dir()``, each exactly once.

    ``_expose_members`` binds every exported name as a global AND appends it
    to ``__all__``. Concatenating the two lists therefore reported each export
    twice, so the union is taken instead.
    """
    return sorted(set(globals()) | set(__all__))
|
sql_mcp/__main__.py
ADDED
sql_mcp/agent_server.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""Graph-based Pydantic AI agent server entry point for sql-mcp (CONCEPT:SQL-1.6)."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
import warnings
|
|
7
|
+
|
|
8
|
+
__version__ = "0.1.0"

# NOTE(review): this configures the root logger as soon as the module is
# imported (sql_mcp.__getattr__ may import it lazily). Consider moving it
# into agent_server().
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Agent identity; stays None until agent_server() resolves it from the
# environment and the workspace identity.
DEFAULT_AGENT_NAME = DEFAULT_AGENT_DESCRIPTION = DEFAULT_AGENT_SYSTEM_PROMPT = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def agent_server():
    """Run the graph-based Pydantic AI agent server for SQL databases.

    Steps:
    1. Prepare the ``agent_utilities`` workspace.
    2. Resolve the agent identity into the module-level ``DEFAULT_AGENT_*``
       globals; environment variables win over the loaded identity.
    3. Silence two noisy warning sources.
    4. Parse the agent CLI and start the server with the parsed options.
    """
    from agent_utilities import (
        build_system_prompt_from_workspace,
        create_agent_parser,
        create_agent_server,
        initialize_workspace,
        load_identity,
    )

    global DEFAULT_AGENT_NAME, DEFAULT_AGENT_DESCRIPTION, DEFAULT_AGENT_SYSTEM_PROMPT

    initialize_workspace()
    identity = load_identity()

    # NOTE(review): the name override reads DEFAULT_AGENT_NAME while the other
    # two read AGENT_DESCRIPTION / AGENT_SYSTEM_PROMPT — confirm the asymmetry
    # is intended.
    name_fallback = identity.get("name", "SQL Agent")
    DEFAULT_AGENT_NAME = os.getenv("DEFAULT_AGENT_NAME", name_fallback)

    description_fallback = identity.get(
        "description", "AI agent for SQL database operations."
    )
    DEFAULT_AGENT_DESCRIPTION = os.getenv("AGENT_DESCRIPTION", description_fallback)

    # The fallback is computed eagerly, so the workspace prompt is built
    # whenever the identity has no "content", even if AGENT_SYSTEM_PROMPT is set.
    prompt_fallback = identity.get("content") or build_system_prompt_from_workspace()
    DEFAULT_AGENT_SYSTEM_PROMPT = os.getenv("AGENT_SYSTEM_PROMPT", prompt_fallback)

    # Suppress urllib3-related warnings and fastmcp deprecation warnings.
    warnings.filterwarnings("ignore", message=".*urllib3.*")
    warnings.filterwarnings("ignore", category=DeprecationWarning, module="fastmcp")

    print(f"{DEFAULT_AGENT_NAME} v{__version__}", file=sys.stderr)
    cli = create_agent_parser().parse_args()
    model = cli.model_id

    create_agent_server(
        mcp_url=cli.mcp_url,
        # NOTE(review): presumably resolved relative to the working directory,
        # not the packaged sql_mcp/mcp_config.json — confirm.
        mcp_config=cli.mcp_config or "mcp_config.json",
        host=cli.host,
        port=cli.port,
        provider=cli.provider,
        # A single model id drives the default, router and agent models.
        model_id=model,
        router_model=model,
        agent_model=model,
        base_url=cli.base_url,
        api_key=cli.api_key,
        custom_skills_directory=cli.custom_skills_directory,
        enable_web_ui=cli.web,
        enable_otel=cli.otel,
        otel_endpoint=cli.otel_endpoint,
        otel_headers=cli.otel_headers,
        otel_public_key=cli.otel_public_key,
        otel_secret_key=cli.otel_secret_key,
        otel_protocol=cli.otel_protocol,
        debug=cli.debug,
    )
|
sql_mcp/api/__init__.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
1
|
+
"""SQLAlchemy 2.x Core facade for sql-mcp (CONCEPT:SQL-1.4).
|
|
2
|
+
|
|
3
|
+
``SqlApi`` is the single API surface the MCP tools call. It owns the named
|
|
4
|
+
connection registry (lazy ``Engine`` per connection), enforces the read-only
|
|
5
|
+
gate, the per-call row cap, and the per-call timeout, and returns bounded
|
|
6
|
+
result envelopes (records + column metadata + truncation flag). All SQL is
|
|
7
|
+
executed through ``sqlalchemy.text()`` with bound parameters — user values are
|
|
8
|
+
never interpolated into statement strings.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import concurrent.futures
import threading
from collections.abc import Callable, Mapping
from typing import Any

from sqlalchemy import create_engine, inspect, select, text
from sqlalchemy.engine import URL, Engine, make_url
from sqlalchemy.pool import StaticPool
from sqlalchemy.schema import CreateTable, MetaData, Table

from sql_mcp import auth
from sql_mcp.dialects import DialectSpec, dialect_for_url, require_driver
from sql_mcp.safety import assert_read_only, assert_single_statement
|
|
23
|
+
|
|
24
|
+
# Version string of this API module (kept in step with the package version).
__version__: str = "0.1.0"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SqlTimeoutError(TimeoutError):
    """A SQL call ran past its per-call time budget.

    Derives from the builtin ``TimeoutError``, so generic timeout handlers
    also catch it.
    """
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class WritesDisabledError(PermissionError):
    """A write was attempted while the server runs in read-only mode.

    Raised by the ``sql_execute`` path when writes are not enabled; derives
    from the builtin ``PermissionError``.
    """
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _is_memory_sqlite(url: URL) -> bool:
    """Return True when ``url`` addresses a private in-memory SQLite database.

    SQLAlchemy treats an absent or empty database path like ``:memory:``.
    """
    if url.get_backend_name() != "sqlite":
        return False
    return not url.database or url.database == ":memory:"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class SqlApi:
    """Multi-connection SQL client over SQLAlchemy Core.

    Constructor arguments left as ``None`` are read from the environment via
    :mod:`sql_mcp.auth`; tests pass them explicitly. One ``Engine`` is created
    lazily per named connection and then reused. In-memory SQLite gets a
    ``StaticPool`` so every call shares one database.

    Timeouts: work runs on a worker thread, and the caller stops waiting after
    the deadline. The worker is not interrupted, so the database keeps running
    the abandoned statement. The write paths (``execute`` / ``execute_script``)
    check an "abandoned" flag before each statement and before committing, and
    roll back if the caller has already timed out. This is best effort: a
    commit that is already under way cannot be recalled.
    """

    def __init__(
        self,
        connections: Mapping[str, URL | str] | None = None,
        allow_writes: bool | None = None,
        max_rows: int | None = None,
        timeout: float | None = None,
    ) -> None:
        """Build the connection registry and resolve the safety settings.

        Args:
            connections: Connection name -> SQLAlchemy URL (object or string).
                ``None`` loads them via ``auth.load_connections()``.
            allow_writes: Enables ``execute`` / ``execute_script``. ``None``
                defers to ``auth.allow_writes()``.
            max_rows: Server-wide row cap. ``None`` defers to
                ``auth.default_max_rows()``.
            timeout: Default per-call timeout in seconds (``<= 0`` disables
                it). ``None`` defers to ``auth.default_timeout()``.

        Raises:
            ValueError: if no connection is configured.
        """
        raw = dict(connections) if connections is not None else auth.load_connections()
        self._connections = {
            name: make_url(url) if isinstance(url, str) else url
            for name, url in raw.items()
        }
        if not self._connections:
            raise ValueError("At least one SQL connection must be configured.")
        self.allow_writes = (
            allow_writes if allow_writes is not None else auth.allow_writes()
        )
        self.max_rows = max_rows if max_rows is not None else auth.default_max_rows()
        self.timeout = timeout if timeout is not None else auth.default_timeout()
        self._engines: dict[str, Engine] = {}

    # ------------------------------------------------------------------ #
    # Connection registry
    # ------------------------------------------------------------------ #

    def connection_names(self) -> list[str]:
        """Names of all configured connections, in configuration order."""
        return list(self._connections)

    def default_connection(self) -> str:
        """The first configured connection — used when a call names none."""
        return next(iter(self._connections))

    def resolve_connection(self, connection: str | None = None) -> str:
        """Map an optional connection name to a configured one.

        Falsy names (``None`` or ``""``) resolve to :meth:`default_connection`.

        Raises:
            ValueError: if ``connection`` is not a configured name.
        """
        if not connection:
            return self.default_connection()
        if connection not in self._connections:
            known = ", ".join(self._connections)
            raise ValueError(f"Unknown connection {connection!r}. Known: {known}.")
        return connection

    def dialect_spec(self, connection: str | None = None) -> DialectSpec | None:
        """The registered :class:`DialectSpec` for a connection, if any."""
        name = self.resolve_connection(connection)
        return dialect_for_url(self._connections[name])

    def engine(self, connection: str | None = None) -> Engine:
        """Return the cached Engine for a connection, creating it on first use.

        When a dialect spec is registered, its driver requirement is checked
        first. In-memory SQLite uses a ``StaticPool`` with
        ``check_same_thread=False``, because statements run on worker threads
        and must all see the same database. Every other URL uses
        ``pool_pre_ping``.
        """
        name = self.resolve_connection(connection)
        eng = self._engines.get(name)
        if eng is None:
            url = self._connections[name]
            spec = dialect_for_url(url)
            if spec is not None:
                require_driver(spec)
            kwargs: dict[str, Any] = {"pool_pre_ping": True}
            if _is_memory_sqlite(url):
                kwargs = {
                    "poolclass": StaticPool,
                    "connect_args": {"check_same_thread": False},
                }
            eng = create_engine(url, **kwargs)
            self._engines[name] = eng
        return eng

    def dispose(self) -> None:
        """Dispose all pooled engines and forget them."""
        for eng in self._engines.values():
            eng.dispose()
        self._engines.clear()

    # ------------------------------------------------------------------ #
    # Internals
    # ------------------------------------------------------------------ #

    def _run_with_timeout(
        self,
        fn: Callable[[], Any],
        timeout: float | None,
        abandoned: threading.Event | None = None,
    ) -> Any:
        """Run ``fn`` on a worker thread and wait at most ``timeout`` seconds.

        Args:
            fn: Zero-argument callable doing the database work.
            timeout: Seconds to wait. ``None`` uses ``self.timeout``; a value
                ``<= 0`` runs ``fn`` inline on the calling thread with no limit.
            abandoned: Optional event that is set when the caller gives up, so
                ``fn`` can refuse to commit work nobody is waiting for.

        Returns:
            Whatever ``fn`` returns.

        Raises:
            SqlTimeoutError: if ``fn`` has not finished within the timeout.
                The worker thread is not interrupted and keeps running.
            Exception: anything ``fn`` itself raises is re-raised unchanged.
        """
        effective = self.timeout if timeout is None else float(timeout)
        if effective <= 0:
            return fn()
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        future = executor.submit(fn)
        try:
            return future.result(timeout=effective)
        except concurrent.futures.TimeoutError as exc:
            if future.done():
                # Either fn finished right at the deadline, or the TimeoutError
                # was raised by fn itself (on Python 3.11+ this except clause
                # also matches the builtin TimeoutError, e.g. a socket
                # timeout). Both are fn's real outcome, so report that.
                return future.result()
            if abandoned is not None:
                abandoned.set()
            raise SqlTimeoutError(
                f"Statement exceeded the {effective:g}s timeout and was abandoned."
            ) from exc
        finally:
            executor.shutdown(wait=False, cancel_futures=True)

    @staticmethod
    def _raise_if_abandoned(abandoned: threading.Event) -> None:
        """Abort the surrounding transaction if the caller already timed out.

        Raising inside ``with engine.begin()`` makes SQLAlchemy roll the
        transaction back instead of committing it.
        """
        if abandoned.is_set():
            raise SqlTimeoutError(
                "Caller timed out before commit; the transaction was rolled back."
            )

    def _effective_max_rows(self, max_rows: int | None) -> int:
        """Clamp a requested row limit to the server cap.

        ``None`` or a value ``<= 0`` means "use the server cap".
        """
        if max_rows is None or max_rows <= 0:
            return self.max_rows
        return min(int(max_rows), self.max_rows)

    @staticmethod
    def _column_type(rows: list[dict[str, Any]], column: str) -> str:
        """Python type name of the first non-NULL value fetched for ``column``.

        Returns ``"unknown"`` when there are no rows, and ``"NoneType"`` when
        every fetched value is NULL.
        """
        if not rows:
            return "unknown"
        for row in rows:
            value = row[column]
            if value is not None:
                return type(value).__name__
        return "NoneType"

    @staticmethod
    def _result_envelope(result: Any, cap: int) -> dict[str, Any]:
        """Fetch at most ``cap`` rows from ``result`` and describe its columns.

        One extra row is requested, so truncation is detected without reading
        the whole result. A result that returns no rows yields an empty
        envelope instead of raising.

        Returns:
            ``{"columns": [{"name", "type"}], "rows": [dict], "row_count":
            int, "truncated": bool}``. Each ``type`` is inferred from the
            fetched Python values, not from the database schema.
        """
        if not getattr(result, "returns_rows", True):
            return {"columns": [], "rows": [], "row_count": 0, "truncated": False}
        columns = list(result.keys())
        fetched = result.fetchmany(cap + 1)
        rows = [dict(zip(columns, row, strict=False)) for row in fetched[:cap]]
        return {
            "columns": [
                {"name": col, "type": SqlApi._column_type(rows, col)}
                for col in columns
            ],
            "rows": rows,
            "row_count": len(rows),
            "truncated": len(fetched) > cap,
        }

    # ------------------------------------------------------------------ #
    # Query (read-only)
    # ------------------------------------------------------------------ #

    def query(
        self,
        sql: str,
        params: Mapping[str, Any] | None = None,
        connection: str | None = None,
        max_rows: int | None = None,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        """Execute a read-only SELECT/CTE with bound parameters.

        Enforces the read-only gate, clamps ``max_rows`` to the server cap,
        and bounds execution time.

        Returns:
            ``{"columns", "rows", "row_count", "truncated"}``.

        Raises:
            SqlTimeoutError: if the statement exceeds the timeout.
        """
        assert_read_only(sql)
        cap = self._effective_max_rows(max_rows)
        eng = self.engine(connection)

        def run() -> dict[str, Any]:
            with eng.connect() as conn:
                result = conn.execute(text(sql), dict(params or {}))
                return self._result_envelope(result, cap)

        return self._run_with_timeout(run, timeout)

    def explain(
        self,
        sql: str,
        params: Mapping[str, Any] | None = None,
        connection: str | None = None,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        """Return the dialect's query plan for a read-only statement.

        The statement is prefixed with the dialect's ``explain_prefix``, and
        the plan rows are returned in the usual envelope (capped at
        ``self.max_rows``).

        Raises:
            ValueError: if the dialect has no registered EXPLAIN prefix.
        """
        assert_read_only(sql)
        name = self.resolve_connection(connection)
        spec = self.dialect_spec(name)
        if spec is None or spec.explain_prefix is None:
            dialect = spec.name if spec else self._connections[name].get_backend_name()
            raise ValueError(
                f"EXPLAIN is not supported for dialect {dialect!r} "
                "(MSSQL uses SET SHOWPLAN, which needs a dedicated session)."
            )
        plan_sql = " ".join((spec.explain_prefix, sql))
        eng = self.engine(name)

        def run() -> dict[str, Any]:
            with eng.connect() as conn:
                result = conn.execute(text(plan_sql), dict(params or {}))
                # Plan statements that return no rows get an empty envelope.
                return self._result_envelope(result, self.max_rows)

        return self._run_with_timeout(run, timeout)

    # ------------------------------------------------------------------ #
    # Execute (writes, gated)
    # ------------------------------------------------------------------ #

    def _assert_writes_allowed(self) -> None:
        """Raise :class:`WritesDisabledError` unless writes are enabled."""
        if not self.allow_writes:
            raise WritesDisabledError(
                "Writes are disabled: the server is read-only by default. "
                "Start it with SQL_ALLOW_WRITES=True to enable sql_execute."
            )

    def execute(
        self,
        sql: str,
        params: Mapping[str, Any] | list[Mapping[str, Any]] | None = None,
        connection: str | None = None,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        """Execute one DML/DDL statement in a transaction (writes gate applies).

        ``params`` may be a mapping (single execution) or a list of mappings
        (``executemany``). If the caller times out before the statement
        finishes, the transaction is rolled back rather than committed (best
        effort, see the class docstring).

        Returns:
            ``{"rowcount": n}`` as reported by the driver (may be -1 where the
            driver cannot tell).

        Raises:
            WritesDisabledError: if the server is read-only.
            SqlTimeoutError: if the statement exceeds the timeout.
        """
        self._assert_writes_allowed()
        assert_single_statement(sql)
        eng = self.engine(connection)
        abandoned = threading.Event()

        def run() -> dict[str, Any]:
            with eng.begin() as conn:
                rowcount = conn.execute(text(sql), params or {}).rowcount
                # Last chance to back out before eng.begin() commits.
                self._raise_if_abandoned(abandoned)
            return {"rowcount": rowcount}

        return self._run_with_timeout(run, timeout, abandoned)

    def execute_script(
        self,
        statements: list[str],
        connection: str | None = None,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        """Run several statements in ONE transaction (all-or-nothing).

        Any failure rolls back every prior statement in the list, and so does
        a caller timeout (best effort). Statements are executed without bound
        parameters.

        Returns:
            ``{"statements": count, "rowcounts": [per-statement rowcount]}``.

        Raises:
            WritesDisabledError: if the server is read-only.
            ValueError: if ``statements`` is a bare string or is empty.
            SqlTimeoutError: if the script exceeds the timeout.
        """
        self._assert_writes_allowed()
        # A bare string would otherwise be iterated character by character.
        if isinstance(statements, str):
            raise ValueError(
                "'statements' must be a list of SQL strings, not a single string."
            )
        # Snapshot the list, so the worker thread runs exactly what was validated.
        batch = list(statements)
        if not batch:
            raise ValueError("'statements' must be a non-empty list of SQL strings.")
        for stmt in batch:
            assert_single_statement(stmt)
        eng = self.engine(connection)
        abandoned = threading.Event()

        def run() -> dict[str, Any]:
            rowcounts: list[int] = []
            with eng.begin() as conn:
                for stmt in batch:
                    self._raise_if_abandoned(abandoned)
                    rowcounts.append(conn.execute(text(stmt)).rowcount)
                # Last chance to back out before eng.begin() commits.
                self._raise_if_abandoned(abandoned)
            return {"statements": len(batch), "rowcounts": rowcounts}

        return self._run_with_timeout(run, timeout, abandoned)

    # ------------------------------------------------------------------ #
    # Schema reflection (runs on the calling thread, no per-call timeout)
    # ------------------------------------------------------------------ #

    def list_schemas(self, connection: str | None = None) -> list[str]:
        """List schema names."""
        return list(inspect(self.engine(connection)).get_schema_names())

    def list_tables(
        self, schema: str | None = None, connection: str | None = None
    ) -> list[str]:
        """List table names (optionally within a schema)."""
        return list(inspect(self.engine(connection)).get_table_names(schema=schema))

    def list_views(
        self, schema: str | None = None, connection: str | None = None
    ) -> list[str]:
        """List view names (optionally within a schema)."""
        return list(inspect(self.engine(connection)).get_view_names(schema=schema))

    def list_columns(
        self, table: str, schema: str | None = None, connection: str | None = None
    ) -> list[dict[str, Any]]:
        """Describe a table's columns: name, type, nullable, default, primary_key."""
        cols = inspect(self.engine(connection)).get_columns(table, schema=schema)
        return [
            {
                "name": col["name"],
                "type": str(col["type"]),
                "nullable": bool(col.get("nullable", True)),
                "default": col.get("default"),
                "primary_key": bool(col.get("primary_key", False)),
            }
            for col in cols
        ]

    def list_indexes(
        self, table: str, schema: str | None = None, connection: str | None = None
    ) -> list[dict[str, Any]]:
        """List a table's indexes (name, columns, uniqueness)."""
        idx = inspect(self.engine(connection)).get_indexes(table, schema=schema)
        return [
            {
                "name": entry.get("name"),
                "columns": list(entry.get("column_names") or []),
                "unique": bool(entry.get("unique", False)),
            }
            for entry in idx
        ]

    def list_foreign_keys(
        self, table: str, schema: str | None = None, connection: str | None = None
    ) -> list[dict[str, Any]]:
        """List a table's foreign keys (columns -> referred table/columns)."""
        fks = inspect(self.engine(connection)).get_foreign_keys(table, schema=schema)
        return [
            {
                "name": entry.get("name"),
                "columns": list(entry.get("constrained_columns") or []),
                "referred_table": entry.get("referred_table"),
                "referred_columns": list(entry.get("referred_columns") or []),
            }
            for entry in fks
        ]

    def table_ddl(
        self, table: str, schema: str | None = None, connection: str | None = None
    ) -> str:
        """Reflect a table and render its CREATE TABLE DDL for this dialect."""
        eng = self.engine(connection)
        metadata = MetaData()
        reflected = Table(table, metadata, schema=schema, autoload_with=eng)
        return str(CreateTable(reflected).compile(eng)).strip()

    def sample_rows(
        self,
        table: str,
        schema: str | None = None,
        limit: int = 10,
        connection: str | None = None,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        """Return up to ``limit`` rows from a table (the server cap still applies).

        A ``limit`` of zero or less means "use the server cap". The query is
        built with SQLAlchemy Core ``select()`` on the reflected table, so the
        identifier is quoted by SQLAlchemy and never interpolated by hand.
        Reflection itself runs before, and outside, the timeout.
        """
        eng = self.engine(connection)
        cap = self._effective_max_rows(limit)
        metadata = MetaData()
        reflected = Table(table, metadata, schema=schema, autoload_with=eng)

        def run() -> dict[str, Any]:
            with eng.connect() as conn:
                # Fetch one extra row so the envelope can flag truncation.
                result = conn.execute(select(reflected).limit(cap + 1))
                return self._result_envelope(result, cap)

        return self._run_with_timeout(run, timeout)

    # ------------------------------------------------------------------ #
    # Admin
    # ------------------------------------------------------------------ #

    def ping(self, connection: str | None = None) -> dict[str, Any]:
        """Connection test: ``SELECT 1`` round-trip with latency in milliseconds."""
        import time

        eng = self.engine(connection)
        started = time.monotonic()

        def run() -> None:
            with eng.connect() as conn:
                conn.execute(text("SELECT 1"))

        self._run_with_timeout(run, None)
        return {
            "connection": self.resolve_connection(connection),
            "ok": True,
            "latency_ms": round((time.monotonic() - started) * 1000, 2),
        }

    def server_version(self, connection: str | None = None) -> dict[str, Any]:
        """Report the server version.

        Uses the dialect's ``version_sql`` when registered and it returns a
        value. Otherwise falls back to SQLAlchemy's ``server_version_info``
        probe, then to ``"unknown"``.
        """
        name = self.resolve_connection(connection)
        spec = self.dialect_spec(name)
        eng = self.engine(name)

        def run() -> dict[str, Any]:
            with eng.connect() as conn:
                version: str | None = None
                if spec is not None and spec.version_sql:
                    raw = conn.execute(text(spec.version_sql)).scalar()
                    # A NULL result must fall through to the probe below,
                    # not be reported as the string "None".
                    if raw is not None:
                        version = str(raw)
                info = getattr(conn.dialect, "server_version_info", None)
                return {
                    "connection": name,
                    "dialect": eng.dialect.name,
                    "version": version
                    or (".".join(str(part) for part in info) if info else "unknown"),
                }

        return self._run_with_timeout(run, None)

    def active_connections(self, connection: str | None = None) -> dict[str, Any]:
        """List active server sessions where the dialect supports it.

        Returns ``{"connection", "supported": False, "detail"}`` when the
        dialect has no active-session query. Otherwise returns the row
        envelope plus ``connection`` and ``supported: True``.
        """
        name = self.resolve_connection(connection)
        spec = self.dialect_spec(name)
        if spec is None or spec.active_connections_sql is None:
            dialect = spec.name if spec else self._connections[name].get_backend_name()
            return {
                "connection": name,
                "supported": False,
                "detail": f"Dialect {dialect!r} has no active-session view.",
            }
        eng = self.engine(name)
        active_sql = spec.active_connections_sql

        def run() -> dict[str, Any]:
            with eng.connect() as conn:
                result = conn.execute(text(active_sql))
                envelope = self._result_envelope(result, self.max_rows)
                envelope.update({"connection": name, "supported": True})
                return envelope

        return self._run_with_timeout(run, None)

    def describe_connections(self) -> list[dict[str, Any]]:
        """Describe configured connections with passwords redacted."""
        default = self.default_connection()
        described = []
        for name, url in self._connections.items():
            spec = dialect_for_url(url)
            described.append(
                {
                    "name": name,
                    "url": url.render_as_string(hide_password=True),
                    "dialect": spec.name if spec else url.get_backend_name(),
                    "default": name == default,
                }
            )
        return described
|