sqlserver-semantic-mcp 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlserver_semantic_mcp/__init__.py +1 -0
- sqlserver_semantic_mcp/config.py +78 -0
- sqlserver_semantic_mcp/domain/__init__.py +0 -0
- sqlserver_semantic_mcp/domain/enums.py +48 -0
- sqlserver_semantic_mcp/domain/models/__init__.py +0 -0
- sqlserver_semantic_mcp/domain/models/column.py +14 -0
- sqlserver_semantic_mcp/domain/models/object.py +13 -0
- sqlserver_semantic_mcp/domain/models/relationship.py +11 -0
- sqlserver_semantic_mcp/domain/models/table.py +29 -0
- sqlserver_semantic_mcp/infrastructure/__init__.py +0 -0
- sqlserver_semantic_mcp/infrastructure/background.py +59 -0
- sqlserver_semantic_mcp/infrastructure/cache/__init__.py +0 -0
- sqlserver_semantic_mcp/infrastructure/cache/semantic.py +132 -0
- sqlserver_semantic_mcp/infrastructure/cache/store.py +152 -0
- sqlserver_semantic_mcp/infrastructure/cache/structural.py +203 -0
- sqlserver_semantic_mcp/infrastructure/connection.py +78 -0
- sqlserver_semantic_mcp/infrastructure/queries/__init__.py +0 -0
- sqlserver_semantic_mcp/infrastructure/queries/comment_queries.py +18 -0
- sqlserver_semantic_mcp/infrastructure/queries/metadata_queries.py +70 -0
- sqlserver_semantic_mcp/infrastructure/queries/object_queries.py +15 -0
- sqlserver_semantic_mcp/main.py +90 -0
- sqlserver_semantic_mcp/policy/__init__.py +0 -0
- sqlserver_semantic_mcp/policy/analyzer.py +194 -0
- sqlserver_semantic_mcp/policy/enforcer.py +104 -0
- sqlserver_semantic_mcp/policy/intents/__init__.py +16 -0
- sqlserver_semantic_mcp/policy/intents/ast_analyzer.py +24 -0
- sqlserver_semantic_mcp/policy/intents/base.py +17 -0
- sqlserver_semantic_mcp/policy/intents/regex_analyzer.py +11 -0
- sqlserver_semantic_mcp/policy/intents/router.py +21 -0
- sqlserver_semantic_mcp/policy/loader.py +90 -0
- sqlserver_semantic_mcp/policy/models.py +43 -0
- sqlserver_semantic_mcp/server/__init__.py +0 -0
- sqlserver_semantic_mcp/server/app.py +125 -0
- sqlserver_semantic_mcp/server/compact.py +74 -0
- sqlserver_semantic_mcp/server/prompts/__init__.py +5 -0
- sqlserver_semantic_mcp/server/prompts/analysis.py +56 -0
- sqlserver_semantic_mcp/server/prompts/discovery.py +55 -0
- sqlserver_semantic_mcp/server/prompts/execution.py +64 -0
- sqlserver_semantic_mcp/server/prompts/registry.py +41 -0
- sqlserver_semantic_mcp/server/resources/__init__.py +1 -0
- sqlserver_semantic_mcp/server/resources/schema.py +144 -0
- sqlserver_semantic_mcp/server/tools/__init__.py +42 -0
- sqlserver_semantic_mcp/server/tools/cache.py +24 -0
- sqlserver_semantic_mcp/server/tools/metadata.py +167 -0
- sqlserver_semantic_mcp/server/tools/metrics.py +44 -0
- sqlserver_semantic_mcp/server/tools/object_tool.py +113 -0
- sqlserver_semantic_mcp/server/tools/policy.py +48 -0
- sqlserver_semantic_mcp/server/tools/query.py +159 -0
- sqlserver_semantic_mcp/server/tools/relationship.py +104 -0
- sqlserver_semantic_mcp/server/tools/semantic.py +112 -0
- sqlserver_semantic_mcp/server/tools/shape.py +204 -0
- sqlserver_semantic_mcp/server/tools/workflow.py +307 -0
- sqlserver_semantic_mcp/services/__init__.py +0 -0
- sqlserver_semantic_mcp/services/metadata_service.py +173 -0
- sqlserver_semantic_mcp/services/metrics_service.py +124 -0
- sqlserver_semantic_mcp/services/object_service.py +187 -0
- sqlserver_semantic_mcp/services/policy_service.py +59 -0
- sqlserver_semantic_mcp/services/query_service.py +321 -0
- sqlserver_semantic_mcp/services/relationship_service.py +160 -0
- sqlserver_semantic_mcp/services/semantic_service.py +277 -0
- sqlserver_semantic_mcp/workflows/__init__.py +26 -0
- sqlserver_semantic_mcp/workflows/bundle.py +157 -0
- sqlserver_semantic_mcp/workflows/contracts.py +64 -0
- sqlserver_semantic_mcp/workflows/discovery_flow.py +116 -0
- sqlserver_semantic_mcp/workflows/facade.py +117 -0
- sqlserver_semantic_mcp/workflows/query_flow.py +120 -0
- sqlserver_semantic_mcp/workflows/recommendations.py +161 -0
- sqlserver_semantic_mcp/workflows/router.py +59 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/METADATA +679 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/RECORD +74 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/WHEEL +5 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/entry_points.txt +2 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/licenses/LICENSE +21 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import re
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from ..config import Config, get_config
|
|
6
|
+
from ..domain.enums import SqlOperation
|
|
7
|
+
from ..infrastructure.cache.semantic import (
|
|
8
|
+
get_object_definition, upsert_object_definition,
|
|
9
|
+
)
|
|
10
|
+
from ..infrastructure.cache.structural import read_schema_version
|
|
11
|
+
from ..infrastructure.connection import fetch_one, fetch_all
|
|
12
|
+
from ..infrastructure.queries.object_queries import (
|
|
13
|
+
GET_OBJECT_DEFINITION, GET_OBJECT_DEPENDENCIES,
|
|
14
|
+
)
|
|
15
|
+
from ..policy.analyzer import (
|
|
16
|
+
_strip_comments, _split_statements, _detect_operation, _IDENT,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# SqlOperation values classed as writes / side-effecting statements:
# DML (UPDATE/INSERT/DELETE/MERGE/TRUNCATE), DDL (DROP/ALTER/CREATE) and
# procedure execution (EXEC/EXECUTE).
# NOTE(review): not referenced in the visible part of this module — confirm
# whether other modules import it before changing or removing it.
_WRITE_OPS = {
    SqlOperation.UPDATE, SqlOperation.INSERT, SqlOperation.DELETE,
    SqlOperation.MERGE, SqlOperation.TRUNCATE,
    SqlOperation.DROP, SqlOperation.ALTER, SqlOperation.CREATE,
    SqlOperation.EXEC, SqlOperation.EXECUTE,
}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _write_target(sql: str, operation: SqlOperation) -> Optional[str]:
    """Return the first target identifier of *operation* in *sql*, or ``None``.

    Only UPDATE, INSERT, DELETE, MERGE and TRUNCATE have a target pattern;
    any other operation yields ``None``. Matching is case-insensitive.
    The INSERT pattern keys on any ``INTO <ident>``, so it also matches
    ``SELECT ... INTO <ident>``. FROM after DELETE and INTO after MERGE are
    optional, as in T-SQL.
    """
    target_patterns = {
        SqlOperation.UPDATE: rf"\bUPDATE\s+({_IDENT})",
        SqlOperation.INSERT: rf"\bINTO\s+({_IDENT})",
        SqlOperation.DELETE: rf"\bDELETE\s+(?:FROM\s+)?({_IDENT})",
        SqlOperation.MERGE: rf"\bMERGE\s+(?:INTO\s+)?({_IDENT})",
        SqlOperation.TRUNCATE: rf"\bTRUNCATE\s+TABLE\s+({_IDENT})",
    }
    pattern = target_patterns.get(operation)
    if pattern is None:
        # Operation without a single write target (SELECT, DROP, EXEC, ...).
        return None
    found = re.search(pattern, sql, re.IGNORECASE)
    return None if found is None else found.group(1)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _from_join_sources(sql: str) -> list[str]:
    """List identifiers that follow FROM, then those that follow JOIN.

    Matching is case-insensitive. Duplicates are kept, and all FROM hits
    come before all JOIN hits.
    NOTE(review): not referenced in the visible part of this module.
    """
    sources: list[str] = []
    for keyword in ("FROM", "JOIN"):
        sources += re.findall(rf"\b{keyword}\s+({_IDENT})", sql, re.IGNORECASE)
    return sources
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Words that can directly follow UPDATE / INSERT / DELETE / MERGE without
# naming a target table. Examples:
#   - trigger headers: "AFTER INSERT, UPDATE AS", "FOR DELETE NOT FOR REPLICATION AS"
#   - FK actions: "ON DELETE CASCADE", "ON UPDATE SET NULL", "NO ACTION"
#   - permission lists: "GRANT INSERT ON"
#   - "UPDATE STATISTICS" and "DELETE TOP (n)"
# FROM / INTO are listed too, so regex backtracking over the optional
# FROM/INTO group can never capture the keyword itself as a table name.
_NON_TARGET_KEYWORD = (
    r"(?!(?:AS|ON|FROM|INTO|TOP|SET|NO|NOT|CASCADE|STATISTICS)\b)"
)

# Write-target patterns used by split_read_write(); group 1 is the target.
# T-SQL makes INTO optional after INSERT and MERGE, and FROM optional after
# DELETE. "INSERT dbo.T ..." and "DELETE dbo.T WHERE ..." are therefore
# recognised like their INTO / FROM forms, consistent with _write_target(),
# which already treats FROM as optional for DELETE.
# These patterns are applied with re.IGNORECASE by the caller, so the
# keyword guard is case-insensitive too.
_WRITE_PATTERNS = [
    rf"\bUPDATE\s+{_NON_TARGET_KEYWORD}({_IDENT})",
    rf"\bINSERT\s+(?:INTO\s+)?{_NON_TARGET_KEYWORD}({_IDENT})",
    rf"\bDELETE\s+(?:FROM\s+)?{_NON_TARGET_KEYWORD}({_IDENT})",
    rf"\bMERGE\s+(?:INTO\s+)?{_NON_TARGET_KEYWORD}({_IDENT})",
    rf"\bTRUNCATE\s+TABLE\s+({_IDENT})",
]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def split_read_write(sql: str) -> tuple[list[str], list[str]]:
    """Classify the tables referenced by a SQL body as read vs. written.

    Intended for module bodies such as a stored-procedure definition.
    Comments are stripped first. Then:

    * write tables: every match of ``_WRITE_PATTERNS`` (UPDATE / INSERT /
      DELETE / MERGE / TRUNCATE TABLE targets), grouped in pattern order;
    * read tables: identifiers after FROM or JOIN, scanned only after
      ``DELETE FROM <target>`` fragments have been removed.

    Both lists are de-duplicated case-insensitively, keeping the first
    spelling seen. A table that is also written is dropped from the read
    list (write intent wins).

    This is a regex heuristic: CTE names may show up as reads, and dynamic
    SQL is not visible. Empty or whitespace-only input yields ``([], [])``.
    """
    if not sql or not sql.strip():
        return [], []

    body = _strip_comments(sql)

    def _unique_ci(names: list[str]) -> list[str]:
        # Order-preserving, case-insensitive de-dup; first spelling wins.
        first_seen: dict[str, str] = {}
        for name in names:
            first_seen.setdefault(name.lower(), name)
        return list(first_seen.values())

    found_writes = [
        hit
        for pattern in _WRITE_PATTERNS
        for hit in re.findall(pattern, body, re.IGNORECASE)
    ]

    # Blank out DELETE FROM targets so the FROM scan does not count them as reads.
    without_deletes = re.sub(
        rf"\bDELETE\s+FROM\s+{_IDENT}", "", body, flags=re.IGNORECASE,
    )
    found_reads = [
        hit
        for keyword in ("FROM", "JOIN")
        for hit in re.findall(
            rf"\b{keyword}\s+({_IDENT})", without_deletes, re.IGNORECASE,
        )
    ]

    write_tables = _unique_ci(found_writes)
    written = {name.lower() for name in write_tables}
    read_tables = [
        name for name in _unique_ci(found_reads) if name.lower() not in written
    ]
    return read_tables, write_tables
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _augment_read_write(obj: dict) -> dict:
    """Return *obj* with read/write table lists derived from its definition.

    When ``obj["definition"]`` is a non-empty string, a shallow copy is
    returned with ``read_tables`` and ``write_tables`` from
    ``split_read_write``, and ``affected_tables`` overwritten with the write
    list (legacy alias).

    If the split raises, the error is logged and the copy falls back to
    ``dependencies`` as reads and an empty write list.

    In every other case (falsy *obj*, or a missing, empty or non-string
    definition) the very same *obj* is returned, without the extra keys.
    """
    if not obj:
        return obj
    definition = obj.get("definition")
    if not (isinstance(definition, str) and definition):
        return obj

    try:
        reads, writes = split_read_write(definition)
    except Exception:
        logger.exception("split_read_write failed; falling back")
        reads, writes = obj.get("dependencies", []) or [], []

    augmented = dict(obj)
    augmented.update(
        read_tables=reads,
        write_tables=writes,
        # Legacy key kept for older clients; now mirrors write_tables.
        affected_tables=writes,
    )
    return augmented
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
async def describe_object(
    schema: str, object_name: str, object_type: str,
    cfg: Optional[Config] = None,
) -> dict:
    """Return the definition and dependencies of ``schema.object_name``.

    Cache-first behaviour:

    * The cached entry is returned (after ``_augment_read_write``) when its
      status is ``"ready"`` and its ``object_hash`` equals the current schema
      version's ``object_hash`` (``""`` when no version is recorded).
    * Otherwise the definition and dependency rows are fetched for the
      qualified name, written back to the cache with status ``"ready"``, and
      returned through ``_augment_read_write``.

    If fetching or caching the fresh data raises, the error is logged, an
    ``"error"`` entry is cached, and
    ``{"status": "error", "error_message": ...}`` is returned. Errors from
    the initial cache reads, or from that error-path cache write,
    propagate to the caller.
    """
    config = cfg or get_config()
    database = config.mssql_database
    schema_version = await read_schema_version(config.cache_path, database)
    current_hash = schema_version["object_hash"] if schema_version else ""

    cached = await get_object_definition(
        config.cache_path, database, schema, object_name, object_type,
    )
    is_fresh = (
        bool(cached)
        and cached["status"] == "ready"
        and cached.get("object_hash") == current_hash
    )
    if is_fresh:
        return _augment_read_write(cached)

    qualified = f"{schema}.{object_name}"
    try:
        # NOTE(review): fetch_one / fetch_all are called without await inside
        # this coroutine, so any blocking I/O they perform runs on the event
        # loop thread.
        def_row = fetch_one(config, GET_OBJECT_DEFINITION, (qualified,))
        definition = def_row[0] if def_row and def_row[0] else None

        dep_rows = fetch_all(config, GET_OBJECT_DEPENDENCIES, (qualified,))
        dependencies = [f"{row[0]}.{row[1]}" for row in dep_rows if row[0]]
        # Rows whose third column mentions TABLE (presumably the referenced
        # object's type — confirm against GET_OBJECT_DEPENDENCIES).
        affected = [
            f"{row[0]}.{row[1]}" for row in dep_rows
            if row[2] and "TABLE" in str(row[2]).upper()
        ]

        await upsert_object_definition(
            config.cache_path, database, schema, object_name, object_type,
            object_hash=current_hash, status="ready",
            definition=definition, dependencies=dependencies,
            affected_tables=affected,
        )
        fresh = {
            "database_name": database,
            "schema": schema,
            "object_name": object_name,
            "object_type": object_type,
            "object_hash": current_hash,
            "status": "ready",
            "definition": definition,
            "dependencies": dependencies,
            "affected_tables": affected,
        }
        return _augment_read_write(fresh)
    except Exception as exc:
        logger.exception("describe_object failed")
        await upsert_object_definition(
            config.cache_path, database, schema, object_name, object_type,
            object_hash=current_hash, status="error",
            error_message=str(exc),
        )
        return {"status": "error", "error_message": str(exc)}
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
async def trace_dependencies(
    schema: str, object_name: str, object_type: str,
    cfg: Optional[Config] = None,
) -> list[str]:
    """Return the ``dependencies`` list reported by ``describe_object``.

    Yields ``[]`` when the description is empty or has no ``dependencies``
    key (as with the error result of ``describe_object``).
    """
    described = await describe_object(schema, object_name, object_type, cfg)
    if not described:
        return []
    return described.get("dependencies", [])
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from ..config import Config, get_config
|
|
4
|
+
from ..policy.analyzer import SqlIntent
|
|
5
|
+
from ..policy.enforcer import enforce
|
|
6
|
+
from ..policy.intents import get_analyzer
|
|
7
|
+
from ..policy.loader import load_active_policy
|
|
8
|
+
from ..policy.models import PolicyProfile
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PolicyService:
    """Owns the active policy profile and the SQL intent analyzer.

    The policy is loaded lazily on first use (or eagerly via ``load``).
    The analyzer comes from ``get_analyzer`` for the same config.
    """

    def __init__(self, cfg: Optional[Config] = None) -> None:
        config = cfg if cfg else get_config()
        self._cfg = config
        self._policy: Optional[PolicyProfile] = None
        self._analyzer = get_analyzer(config)

    def load(self) -> None:
        """Load the active policy via ``load_active_policy`` and cache it."""
        self._policy = load_active_policy(self._cfg)

    def reload(self) -> None:
        """Re-load the policy, then rebuild the analyzer from config."""
        self.load()
        self._analyzer = get_analyzer(self._cfg)

    def current_policy(self) -> PolicyProfile:
        """Return the cached policy, loading it on first access."""
        if self._policy is not None:
            return self._policy
        self.load()
        assert self._policy is not None
        return self._policy

    def analyze(self, sql: str) -> SqlIntent:
        """Run the configured analyzer over *sql*."""
        return self._analyzer.analyze(sql)

    def validate(self, sql: str, database: str = "") -> dict:
        """Analyze *sql* and enforce the current policy against it.

        Returns a dict with ``allowed``, ``reason`` and the serialized
        ``intent``.
        """
        active = self.current_policy()
        parsed = self._analyzer.analyze(sql)
        verdict = enforce(parsed, active, database=database)
        return dict(
            allowed=verdict.allowed,
            reason=verdict.reason,
            intent=intent_to_dict(parsed),
        )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def intent_to_dict(intent: SqlIntent) -> dict:
    """Serialize a ``SqlIntent`` into a plain, JSON-friendly dict.

    Enum-valued fields (``primary_operation``, ``risk_level``) are emitted
    as their ``.value``. All other fields are copied as-is, in a fixed key
    order.
    """
    field_names = (
        "primary_operation",
        "has_where_clause",
        "has_top_clause",
        "affected_tables",
        "risk_level",
        "is_multi_statement",
        "statement_count",
        "is_sql_like",
        "confidence",
        "requires_discovery",
        "has_unqualified_tables",
        "contains_dynamic_sql",
        "contains_cte",
    )
    enum_valued = {"primary_operation", "risk_level"}
    payload: dict = {}
    for name in field_names:
        value = getattr(intent, name)
        payload[name] = value.value if name in enum_valued else value
    return payload
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
"""Query service — validation / preview / execution.
|
|
2
|
+
|
|
3
|
+
v0.5 splits the old ``run_safe_query()`` into three explicit phases so
|
|
4
|
+
the workflow layer can route an agent's request down the shortest safe
|
|
5
|
+
path. ``run_safe_query()`` is kept as a thin wrapper over
|
|
6
|
+
``execute_query`` for backwards compatibility.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import Any, Optional
|
|
13
|
+
|
|
14
|
+
from ..config import Config, get_config
|
|
15
|
+
from ..infrastructure.connection import open_connection
|
|
16
|
+
from ..policy.analyzer import SqlIntent
|
|
17
|
+
from .policy_service import PolicyService, intent_to_dict
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class QueryExecutionMode(str, Enum):
    """Execution phases a caller may request.

    Presumably these map to validate_query / preview_query / execute_query.
    NOTE(review): not referenced by QueryService in this module; it is only
    exported via ``__all__``. Confirm its consumer in the tool/workflow layer.
    """

    VALIDATE_ONLY = "validate_only"
    DRY_RUN = "dry_run"
    EXECUTE_IF_SAFE = "execute_if_safe"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AffectedRowsPolicyMode(str, Enum):
    """How a non-SELECT whose rowcount exceeds ``max_rows_affected`` is handled.

    STRICT rolls the transaction back and reports ``executed: False``.
    REPORT commits and flags ``exceeded_cap`` in the result.
    """

    STRICT = "strict"
    REPORT = "report"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# ---- response_mode helpers --------------------------------------------------
|
|
34
|
+
|
|
35
|
+
# Accepted values for execute_query's ``response_mode``.
_VALID_RESPONSE_MODES = {"summary", "rows", "sample", "count_only"}


def _normalize_response_mode(value: Optional[str], default: str) -> str:
    """Return *value* if it is a known response mode, *default* if it is None.

    Raises:
        ValueError: when *value* is given but is not in
            ``_VALID_RESPONSE_MODES``. *default* itself is not checked.
    """
    if value is not None and value not in _VALID_RESPONSE_MODES:
        choices = sorted(_VALID_RESPONSE_MODES)
        raise ValueError(
            f"invalid response_mode '{value}'; expected one of {choices}"
        )
    return default if value is None else value
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# ---- budget hint ------------------------------------------------------------
|
|
50
|
+
|
|
51
|
+
# Rows returned in "sample" response mode, keyed by token-budget hint.
_BUDGET_SAMPLE_ROWS = {
    "tiny": 3,
    "low": 10,
    "medium": 50,
    "high": 200,
}


def sample_row_cap(budget: Optional[str]) -> int:
    """Return the sample size for a budget hint.

    A missing or empty hint counts as ``"low"``. Unknown hints also get 10.
    """
    key = budget if budget else "low"
    try:
        return _BUDGET_SAMPLE_ROWS[key]
    except KeyError:
        return 10
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# ---- service ----------------------------------------------------------------
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class QueryService:
    """Policy-gated SQL validation, preview and execution.

    Every phase consults the injected ``PolicyService``. Only
    ``execute_query`` opens a database connection.
    """

    def __init__(
        self,
        policy_service: PolicyService,
        cfg: Optional[Config] = None,
    ) -> None:
        self._policy = policy_service
        self._cfg = cfg or get_config()

    # ------------------------------------------------------------------ 1. validate

    def validate(self, sql: str, database: str = "") -> dict:
        """Backwards-compatible validation façade.

        Returns ``PolicyService.validate``'s dict unchanged. An empty
        *database* means the configured ``mssql_database``.
        """
        db = database or self._cfg.mssql_database
        return self._policy.validate(sql, database=db)

    def validate_query(self, sql: str, database: str = "") -> dict:
        """Return validation + intent, agent-envelope friendly.

        ``next_action`` is ``"execute"`` when the policy allows the SQL,
        otherwise ``"revise_query"``.
        """
        db = database or self._cfg.mssql_database
        validation = self._policy.validate(sql, database=db)
        intent = validation["intent"]
        next_action = "execute" if validation["allowed"] else "revise_query"
        return {
            "kind": "query_validation",
            "allowed": validation["allowed"],
            "reason": validation["reason"],
            "intent": intent,
            "risk": intent["risk_level"],
            "tables": intent["affected_tables"],
            "next_action": next_action,
        }

    # ------------------------------------------------------------------ 2. preview

    def preview_query(
        self,
        sql: str,
        *,
        max_rows: Optional[int] = None,
        database: str = "",
    ) -> dict:
        """Cheap dry-run: return what WOULD happen, without side effects.

        ``max_rows_applied`` is the same limit ``execute_query`` would use:
        *max_rows* capped by the policy's ``max_rows_returned``.
        """
        db = database or self._cfg.mssql_database
        policy = self._policy.current_policy()
        validation = self._policy.validate(sql, database=db)
        intent = validation["intent"]
        limit = self._resolve_row_limit(
            max_rows, policy.constraints.max_rows_returned,
        )

        return {
            "kind": "query_preview",
            "operation": intent["primary_operation"],
            "tables": intent["affected_tables"],
            "allowed": validation["allowed"],
            "reason": validation["reason"],
            "risk": intent["risk_level"],
            "max_rows_applied": limit,
            "max_rows_affected": policy.constraints.max_rows_affected,
            "is_multi_statement": intent["is_multi_statement"],
            "has_where_clause": intent["has_where_clause"],
            "has_unqualified_tables": intent["has_unqualified_tables"],
            "contains_dynamic_sql": intent["contains_dynamic_sql"],
            "next_action": "execute" if validation["allowed"] else "revise_query",
        }

    # ------------------------------------------------------------------ 3. execute

    def execute_query(
        self,
        sql: str,
        *,
        max_rows: Optional[int] = None,
        response_mode: Optional[str] = None,
        token_budget_hint: Optional[str] = None,
        affected_rows_policy: Optional[str] = None,
        database: str = "",
    ) -> dict:
        """Execute SQL after policy validation.

        response_mode (``None`` -> ``cfg.default_response_mode``):
            summary    — columns + row_count only
            rows       — columns + rows
            sample     — columns + first N rows (N = budget-derived)
            count_only — row_count only

        max_rows:
            Row limit for SELECT results. It may lower, but never raise, the
            policy's ``max_rows_returned``. ``None``, 0 or a negative value
            means "use the policy limit".

        affected_rows_policy (``None`` -> ``cfg.strict_rows_affected_cap``):
            ``"strict"`` rolls back a write whose rowcount exceeds the
            policy's ``max_rows_affected``. ``"report"`` commits and flags
            ``exceeded_cap``.

        Raises:
            ValueError: for an unknown ``response_mode`` or
                ``affected_rows_policy``.

        Database errors are not raised. They are returned as
        ``{"executed": False, "error": ..., "next_action": "revise_query"}``.
        """
        mode = _normalize_response_mode(
            response_mode, self._cfg.default_response_mode,
        )
        budget = token_budget_hint or self._cfg.default_token_budget_hint

        strict_cap = self._cfg.strict_rows_affected_cap
        if affected_rows_policy is not None:
            strict_cap = self._is_strict_affected_rows_policy(affected_rows_policy)

        db = database or self._cfg.mssql_database
        policy = self._policy.current_policy()
        limit = self._resolve_row_limit(
            max_rows, policy.constraints.max_rows_returned,
        )

        validation = self._policy.validate(sql, database=db)
        if not validation["allowed"]:
            return {
                "executed": False,
                "validation": validation,
                "error": validation["reason"],
                "next_action": "revise_query",
            }

        op = validation["intent"]["primary_operation"]

        try:
            with open_connection(self._cfg) as conn:
                cursor = conn.cursor()
                try:
                    cursor.execute(sql)

                    if op == "SELECT":
                        return self._shape_select(
                            cursor, limit, mode, budget, validation,
                        )

                    return self._shape_non_select(
                        cursor, conn, policy.constraints.max_rows_affected,
                        strict_cap, validation,
                    )
                finally:
                    try:
                        cursor.close()
                    except Exception:
                        logger.warning("Failed to close cursor", exc_info=True)
        except Exception as e:
            logger.exception("Query execution failed")
            return {
                "executed": False,
                "validation": validation,
                "error": str(e),
                "next_action": "revise_query",
            }

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _resolve_row_limit(
        max_rows: Optional[int], policy_cap: Optional[int],
    ) -> Optional[int]:
        """Return the effective SELECT row limit.

        The policy's ``max_rows_returned`` is a ceiling: a caller-supplied
        *max_rows* may lower it but never raise it. A missing, zero or
        negative *max_rows* falls back to the policy value.
        """
        if max_rows is None or max_rows <= 0:
            return policy_cap
        if policy_cap is None or policy_cap <= 0:
            # No usable ceiling configured: honour the caller's value as before.
            return max_rows
        return min(max_rows, policy_cap)

    @staticmethod
    def _is_strict_affected_rows_policy(value: str) -> bool:
        """Parse an ``affected_rows_policy`` value; True when it is ``"strict"``.

        Unknown values raise ``ValueError`` instead of silently degrading to
        the permissive ``"report"`` behaviour.
        """
        try:
            parsed = AffectedRowsPolicyMode(value)
        except ValueError:
            choices = sorted(m.value for m in AffectedRowsPolicyMode)
            raise ValueError(
                f"invalid affected_rows_policy '{value}'; "
                f"expected one of {choices}"
            ) from None
        return parsed is AffectedRowsPolicyMode.STRICT

    def _shape_select(
        self,
        cursor: Any,
        limit: int,
        mode: str,
        budget: Optional[str],
        validation: dict,
    ) -> dict:
        """Fetch at most *limit* rows and shape them according to *mode*.

        One extra row is requested so ``truncated`` can report whether more
        rows were available. ``row_count`` is the number of rows kept
        (at most *limit*), not the full result size.
        """
        columns = [d[0] for d in cursor.description]
        rows = cursor.fetchmany(limit + 1)
        truncated = len(rows) > limit
        rows = rows[:limit]

        if mode == "count_only":
            return {
                "executed": True,
                "validation": validation,
                "row_count": len(rows),
                "truncated": truncated,
                "next_action": "done",
            }

        if mode == "summary":
            return {
                "executed": True,
                "validation": validation,
                "columns": columns,
                "row_count": len(rows),
                "truncated": truncated,
                "next_action": "refine_or_done",
            }

        if mode == "sample":
            cap = min(sample_row_cap(budget), len(rows))
            return {
                "executed": True,
                "validation": validation,
                "columns": columns,
                "row_count": len(rows),
                "truncated": truncated,
                "sample_rows": [list(r) for r in rows[:cap]],
                "sample_size": cap,
                "next_action": "refine_or_done",
            }

        # default: rows
        return {
            "executed": True,
            "validation": validation,
            "columns": columns,
            "rows": [list(r) for r in rows],
            "row_count": len(rows),
            "truncated": truncated,
            "next_action": "done",
        }

    def _shape_non_select(
        self,
        cursor: Any,
        conn: Any,
        cap: int,
        strict_cap: bool,
        validation: dict,
    ) -> dict:
        """Commit or roll back a non-SELECT statement based on its rowcount.

        ``cursor.rowcount`` is compared with *cap* (the policy's
        ``max_rows_affected``).
        - Strict mode with the cap exceeded: roll back and return
          ``executed: False``.
        - Otherwise: commit, with ``exceeded_cap`` reporting whether the cap
          was crossed.
        """
        # DB-API drivers report -1 when the count is unknown; -1 never
        # exceeds a non-negative cap, so such statements are committed.
        affected = cursor.rowcount
        exceeded = affected > cap

        if strict_cap and exceeded:
            # NOTE(review): the rollback only undoes the write if
            # open_connection() hands out a non-autocommit connection —
            # confirm in infrastructure.connection.
            try:
                conn.rollback()
            except Exception:
                logger.warning("Rollback failed", exc_info=True)
            return {
                "executed": False,
                "validation": validation,
                "rows_affected": affected,
                "exceeded_cap": True,
                "error": (
                    f"Affected rows {affected} exceeds cap {cap} under "
                    f"strict rows-affected policy; transaction rolled back"
                ),
                "next_action": "revise_query",
            }

        conn.commit()
        return {
            "executed": True,
            "validation": validation,
            "rows_affected": affected,
            "exceeded_cap": exceeded,
            "next_action": "done",
        }

    # ------------------------------------------------------------------ legacy

    def run_safe_query(
        self,
        sql: str,
        max_rows: Optional[int] = None,
    ) -> dict:
        """Legacy wrapper — preserved for v0.4 clients.

        Always returns full rows and uses the permissive ``"report"``
        affected-rows policy.
        """
        return self.execute_query(
            sql,
            max_rows=max_rows,
            response_mode="rows",
            affected_rows_policy="report",
        )
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
# Public surface of this module. ``intent_to_dict`` (from .policy_service)
# and ``SqlIntent`` (from ..policy.analyzer) are re-exported, so callers can
# import them from the query service module.
__all__ = [
    "QueryService",
    "QueryExecutionMode",
    "AffectedRowsPolicyMode",
    "sample_row_cap",
    "intent_to_dict",
    "SqlIntent",
]
|