sqlserver-semantic-mcp 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlserver_semantic_mcp/__init__.py +1 -0
- sqlserver_semantic_mcp/config.py +78 -0
- sqlserver_semantic_mcp/domain/__init__.py +0 -0
- sqlserver_semantic_mcp/domain/enums.py +48 -0
- sqlserver_semantic_mcp/domain/models/__init__.py +0 -0
- sqlserver_semantic_mcp/domain/models/column.py +14 -0
- sqlserver_semantic_mcp/domain/models/object.py +13 -0
- sqlserver_semantic_mcp/domain/models/relationship.py +11 -0
- sqlserver_semantic_mcp/domain/models/table.py +29 -0
- sqlserver_semantic_mcp/infrastructure/__init__.py +0 -0
- sqlserver_semantic_mcp/infrastructure/background.py +59 -0
- sqlserver_semantic_mcp/infrastructure/cache/__init__.py +0 -0
- sqlserver_semantic_mcp/infrastructure/cache/semantic.py +132 -0
- sqlserver_semantic_mcp/infrastructure/cache/store.py +152 -0
- sqlserver_semantic_mcp/infrastructure/cache/structural.py +203 -0
- sqlserver_semantic_mcp/infrastructure/connection.py +78 -0
- sqlserver_semantic_mcp/infrastructure/queries/__init__.py +0 -0
- sqlserver_semantic_mcp/infrastructure/queries/comment_queries.py +18 -0
- sqlserver_semantic_mcp/infrastructure/queries/metadata_queries.py +70 -0
- sqlserver_semantic_mcp/infrastructure/queries/object_queries.py +15 -0
- sqlserver_semantic_mcp/main.py +90 -0
- sqlserver_semantic_mcp/policy/__init__.py +0 -0
- sqlserver_semantic_mcp/policy/analyzer.py +194 -0
- sqlserver_semantic_mcp/policy/enforcer.py +104 -0
- sqlserver_semantic_mcp/policy/intents/__init__.py +16 -0
- sqlserver_semantic_mcp/policy/intents/ast_analyzer.py +24 -0
- sqlserver_semantic_mcp/policy/intents/base.py +17 -0
- sqlserver_semantic_mcp/policy/intents/regex_analyzer.py +11 -0
- sqlserver_semantic_mcp/policy/intents/router.py +21 -0
- sqlserver_semantic_mcp/policy/loader.py +90 -0
- sqlserver_semantic_mcp/policy/models.py +43 -0
- sqlserver_semantic_mcp/server/__init__.py +0 -0
- sqlserver_semantic_mcp/server/app.py +125 -0
- sqlserver_semantic_mcp/server/compact.py +74 -0
- sqlserver_semantic_mcp/server/prompts/__init__.py +5 -0
- sqlserver_semantic_mcp/server/prompts/analysis.py +56 -0
- sqlserver_semantic_mcp/server/prompts/discovery.py +55 -0
- sqlserver_semantic_mcp/server/prompts/execution.py +64 -0
- sqlserver_semantic_mcp/server/prompts/registry.py +41 -0
- sqlserver_semantic_mcp/server/resources/__init__.py +1 -0
- sqlserver_semantic_mcp/server/resources/schema.py +144 -0
- sqlserver_semantic_mcp/server/tools/__init__.py +42 -0
- sqlserver_semantic_mcp/server/tools/cache.py +24 -0
- sqlserver_semantic_mcp/server/tools/metadata.py +167 -0
- sqlserver_semantic_mcp/server/tools/metrics.py +44 -0
- sqlserver_semantic_mcp/server/tools/object_tool.py +113 -0
- sqlserver_semantic_mcp/server/tools/policy.py +48 -0
- sqlserver_semantic_mcp/server/tools/query.py +159 -0
- sqlserver_semantic_mcp/server/tools/relationship.py +104 -0
- sqlserver_semantic_mcp/server/tools/semantic.py +112 -0
- sqlserver_semantic_mcp/server/tools/shape.py +204 -0
- sqlserver_semantic_mcp/server/tools/workflow.py +307 -0
- sqlserver_semantic_mcp/services/__init__.py +0 -0
- sqlserver_semantic_mcp/services/metadata_service.py +173 -0
- sqlserver_semantic_mcp/services/metrics_service.py +124 -0
- sqlserver_semantic_mcp/services/object_service.py +187 -0
- sqlserver_semantic_mcp/services/policy_service.py +59 -0
- sqlserver_semantic_mcp/services/query_service.py +321 -0
- sqlserver_semantic_mcp/services/relationship_service.py +160 -0
- sqlserver_semantic_mcp/services/semantic_service.py +277 -0
- sqlserver_semantic_mcp/workflows/__init__.py +26 -0
- sqlserver_semantic_mcp/workflows/bundle.py +157 -0
- sqlserver_semantic_mcp/workflows/contracts.py +64 -0
- sqlserver_semantic_mcp/workflows/discovery_flow.py +116 -0
- sqlserver_semantic_mcp/workflows/facade.py +117 -0
- sqlserver_semantic_mcp/workflows/query_flow.py +120 -0
- sqlserver_semantic_mcp/workflows/recommendations.py +161 -0
- sqlserver_semantic_mcp/workflows/router.py +59 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/METADATA +679 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/RECORD +74 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/WHEEL +5 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/entry_points.txt +2 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/licenses/LICENSE +21 -0
- sqlserver_semantic_mcp-0.5.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from ..domain.enums import SqlOperation
|
|
4
|
+
from .analyzer import SqlIntent
|
|
5
|
+
from .models import PolicyProfile
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class EnforcementResult:
    """Outcome of a policy check: whether the SQL may run, and why."""

    allowed: bool  # True when every policy gate passed
    reason: str  # human-readable verdict (deny message or "Query is allowed")
|
+
|
|
14
|
+
# Maps each SqlOperation to the PolicyOperations attribute that gates it.
# EXEC and EXECUTE deliberately share the single "execute" flag.
_OP_FIELD = {
    SqlOperation.SELECT: "select",
    SqlOperation.INSERT: "insert",
    SqlOperation.UPDATE: "update",
    SqlOperation.DELETE: "delete",
    SqlOperation.TRUNCATE: "truncate",
    SqlOperation.CREATE: "create",
    SqlOperation.ALTER: "alter",
    SqlOperation.DROP: "drop",
    SqlOperation.EXEC: "execute",
    SqlOperation.EXECUTE: "execute",
    SqlOperation.MERGE: "merge",
}
|
|
28
|
+
|
|
29
|
+
def _bare(name: str) -> str:
|
|
30
|
+
return name.strip("[]").split(".")[-1].strip("[]")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def enforce(
    intent: SqlIntent, policy: PolicyProfile, database: str = "",
) -> EnforcementResult:
    """Check *intent* against *policy* and return an allow/deny verdict.

    Checks run in order: operation gate, multi-statement gate, WHERE/TOP
    constraints, then scope (database, denied tables, allowed tables,
    allowed schemas). The first failing check determines the reason.
    """
    def deny(reason: str) -> EnforcementResult:
        return EnforcementResult(False, reason)

    op = intent.primary_operation
    if op == SqlOperation.UNKNOWN:
        return deny("Unable to determine SQL operation")

    flag_name = _OP_FIELD.get(op)
    if flag_name is None or not getattr(policy.operations, flag_name, False):
        return deny(f"Operation {op.value} is not allowed by policy")

    constraints = policy.constraints
    if intent.is_multi_statement and not constraints.allow_multi_statement:
        return deny("Multi-statement queries are not allowed")

    missing_where = not intent.has_where_clause
    if op == SqlOperation.UPDATE and constraints.require_where_for_update and missing_where:
        return deny("UPDATE requires a WHERE clause")
    if op == SqlOperation.DELETE and constraints.require_where_for_delete and missing_where:
        return deny("DELETE requires a WHERE clause")
    if op == SqlOperation.SELECT and constraints.require_top_for_select and not intent.has_top_clause:
        return deny("SELECT requires a TOP clause")

    scope = policy.scope

    # Empty allow/deny lists mean "no restriction" for that dimension.
    if scope.allowed_databases and database and database not in scope.allowed_databases:
        return deny(f"Database '{database}' is not allowed by policy")

    bare_names = [_bare(t) for t in intent.affected_tables]

    if scope.denied_tables:
        blocked = next((n for n in bare_names if n in scope.denied_tables), None)
        if blocked is not None:
            return deny(f"Table '{blocked}' is denied by policy")

    if scope.allowed_tables:
        outside = next((n for n in bare_names if n not in scope.allowed_tables), None)
        if outside is not None:
            return deny(f"Table '{outside}' is not in the allowed tables list")

    if scope.allowed_schemas:
        for table_ref in intent.affected_tables:
            parts = table_ref.strip("[]").split(".")
            if len(parts) != 2:
                # NOTE(review): three-part names (db.schema.table) also land
                # here and are reported as "unqualified" — confirm intended.
                return deny(
                    f"Table '{table_ref}' is unqualified; allowed_schemas requires "
                    f"schema-qualified names (e.g., 'dbo.{table_ref}')"
                )
            schema = parts[0].strip("[]")
            if schema not in scope.allowed_schemas:
                return deny(f"Schema '{schema}' is not in the allowed schemas list")

    return EnforcementResult(True, "Query is allowed")
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Pluggable SQL intent analyzers.
|
|
2
|
+
|
|
3
|
+
The regex analyzer is the current default. The AST analyzer is a
|
|
4
|
+
placeholder that falls back to regex until a real parser lands.
|
|
5
|
+
"""
|
|
6
|
+
from .base import IntentAnalyzer
|
|
7
|
+
from .regex_analyzer import RegexIntentAnalyzer
|
|
8
|
+
from .ast_analyzer import AstIntentAnalyzer
|
|
9
|
+
from .router import get_analyzer
|
|
10
|
+
|
|
11
|
+
# Public API of the intents package.
__all__ = [
    "IntentAnalyzer",
    "RegexIntentAnalyzer",
    "AstIntentAnalyzer",
    "get_analyzer",
]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""AST analyzer placeholder.
|
|
2
|
+
|
|
3
|
+
Currently falls back to :class:`RegexIntentAnalyzer`. The slot exists so
|
|
4
|
+
the workflow layer and tests do not need to change when a real parser
|
|
5
|
+
lands.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from ..analyzer import SqlIntent
|
|
10
|
+
from .regex_analyzer import RegexIntentAnalyzer
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AstIntentAnalyzer:
    """Placeholder AST-based analyzer that delegates to the regex one."""

    name = "ast"

    def __init__(self) -> None:
        # Delegate until a real SQL parser is wired in.
        self._fallback = RegexIntentAnalyzer()

    def analyze(self, sql: str) -> SqlIntent:
        """Return the regex analysis with confidence capped at 0.7.

        The cap signals to routing code that the result is provisional,
        not produced by real AST analysis.
        """
        result = self._fallback.analyze(sql)
        if result.confidence > 0.7:
            result.confidence = 0.7
        return result
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Analyzer contract for SQL intent detection."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Protocol
|
|
5
|
+
|
|
6
|
+
from ..analyzer import SqlIntent
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class IntentAnalyzer(Protocol):
    """Analyze a SQL string and return a :class:`SqlIntent`.

    Implementations must not raise for malformed SQL; they should return
    an ``UNKNOWN`` intent with ``is_sql_like=False`` instead, so the
    workflow router can send the request down the discovery path.
    """

    # Structural protocol: any object with a matching analyze() satisfies it.
    def analyze(self, sql: str) -> SqlIntent: ...
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Default regex-based analyzer — wraps the existing ``analyze_sql``."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from ..analyzer import SqlIntent, analyze_sql
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class RegexIntentAnalyzer:
    """Default analyzer: thin wrapper over the module-level ``analyze_sql``."""

    name = "regex"  # registry key used by the router

    def analyze(self, sql: str) -> SqlIntent:
        """Delegate directly to the regex-based ``analyze_sql`` helper."""
        return analyze_sql(sql)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Selects the active intent analyzer based on :class:`Config`."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from ...config import Config, get_config
|
|
7
|
+
from .ast_analyzer import AstIntentAnalyzer
|
|
8
|
+
from .base import IntentAnalyzer
|
|
9
|
+
from .regex_analyzer import RegexIntentAnalyzer
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Analyzer name -> implementation class. Unknown names fall back to the
# regex analyzer in get_analyzer().
_REGISTRY: dict[str, type] = {
    "regex": RegexIntentAnalyzer,
    "ast": AstIntentAnalyzer,
}
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_analyzer(cfg: Optional[Config] = None) -> IntentAnalyzer:
    """Instantiate the analyzer named by ``cfg.intent_analyzer``.

    Falls back to the process-wide config when *cfg* is omitted, and to
    the regex analyzer when the configured name is not registered.
    """
    active = cfg or get_config()
    analyzer_cls = _REGISTRY.get(active.intent_analyzer, RegexIntentAnalyzer)
    return analyzer_cls()
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from ..config import Config, get_config
|
|
7
|
+
from .models import (
|
|
8
|
+
PolicyFile, PolicyProfile, PolicyOperations,
|
|
9
|
+
PolicyConstraints, PolicyScope,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def builtin_readonly() -> PolicyProfile:
    """Built-in fallback profile: SELECT-only, 1000-row cap, single statements."""
    limits = PolicyConstraints(
        max_rows_returned=1000,
        allow_multi_statement=False,
    )
    return PolicyProfile(
        profile_name="readonly",
        operations=PolicyOperations(select=True),
        constraints=limits,
    )
24
|
+
|
|
25
|
+
|
|
26
|
+
def select_profile(pf: PolicyFile, override: Optional[str]) -> PolicyProfile:
    """Return the requested profile with its ``profile_name`` filled in.

    Args:
        pf: Parsed policy file.
        override: Explicit profile name; wins over ``pf.active_profile``.

    Raises:
        ValueError: if the named profile does not exist in *pf*.
    """
    name = override or pf.active_profile
    if name not in pf.profiles:
        raise ValueError(f"Profile '{name}' not found in policy file")
    # Return a copy instead of mutating the entry stored on the shared
    # PolicyFile, so other holders of `pf` never observe a renamed profile.
    return pf.profiles[name].model_copy(update={"profile_name": name})
33
|
+
|
|
34
|
+
|
|
35
|
+
def apply_env_overrides(profile: PolicyProfile, cfg: Config) -> PolicyProfile:
    """Overlay environment-driven limits onto *profile*.

    Round-trips through ``model_dump``/``model_validate`` so the pydantic
    field constraints are re-checked on the overridden values.
    """
    data = profile.model_dump()
    limits = data["constraints"]
    limits["max_rows_returned"] = cfg.max_rows_returned
    limits["max_rows_affected"] = cfg.max_rows_affected
    limits["query_timeout_seconds"] = cfg.query_timeout
    return PolicyProfile.model_validate(data)
41
|
+
|
|
42
|
+
|
|
43
|
+
def _read_policy_file(p: Path, path: str) -> Optional[PolicyFile]:
    """Read, JSON-parse, and schema-validate *p*; log and return None on failure."""
    try:
        text = p.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as e:
        logger.error(
            "Policy file %s unreadable (%s); falling back to readonly", path, e,
        )
        return None

    try:
        raw = json.loads(text)
    except json.JSONDecodeError as e:
        logger.error(
            "Policy file %s has invalid JSON (%s); falling back to readonly",
            path, e,
        )
        return None

    try:
        return PolicyFile.model_validate(raw)
    except Exception as e:
        logger.error(
            "Policy file %s failed schema validation (%s); falling back to readonly",
            path, e,
        )
        return None


def load_policy_from_file(
    path: Optional[str], profile_override: Optional[str],
) -> PolicyProfile:
    """Load a policy profile from *path*, degrading safely on any file error.

    Missing path, missing file, unreadable content, bad JSON, or schema
    failure all fall back to the built-in readonly profile (logged).

    Raises:
        ValueError: if *profile_override* (or the file's active_profile)
            names a profile that does not exist — that is caller
            misconfiguration, not a file problem.
    """
    if not path:
        logger.warning("No policy file specified; using built-in readonly profile")
        return builtin_readonly()

    p = Path(path)
    if not p.exists():
        logger.warning(
            "Policy file %s not found; using built-in readonly profile", path,
        )
        return builtin_readonly()

    pf = _read_policy_file(p, path)
    if pf is None:
        return builtin_readonly()

    # Profile-override errors still raise (caller misconfiguration)
    return select_profile(pf, profile_override)
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def load_active_policy(cfg: Optional[Config] = None) -> PolicyProfile:
    """Load the configured profile and apply environment overrides to it."""
    active_cfg = cfg or get_config()
    base = load_policy_from_file(active_cfg.policy_file, active_cfg.policy_profile)
    return apply_env_overrides(base, active_cfg)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class PolicyOperations(BaseModel):
    """Per-operation allow flags; everything except SELECT is off by default."""

    select: bool = True
    insert: bool = False
    update: bool = False
    delete: bool = False
    truncate: bool = False
    create: bool = False
    alter: bool = False
    drop: bool = False
    execute: bool = False  # gates both EXEC and EXECUTE
    merge: bool = False
|
+
|
|
16
|
+
|
|
17
|
+
class PolicyConstraints(BaseModel):
    """Per-statement guardrails applied at enforcement time."""

    require_where_for_update: bool = True
    require_where_for_delete: bool = True
    require_top_for_select: bool = False
    max_rows_returned: int = Field(default=1000, ge=1)
    max_rows_affected: int = Field(default=100, ge=1)
    allow_multi_statement: bool = False
    query_timeout_seconds: int = Field(default=30, ge=1)
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class PolicyScope(BaseModel):
    """Allow/deny lists for databases, schemas, and tables.

    An empty list means "no restriction" for that dimension (see the
    enforcement checks, which are skipped for empty lists).
    """

    allowed_databases: list[str] = Field(default_factory=list)
    allowed_schemas: list[str] = Field(default_factory=list)
    allowed_tables: list[str] = Field(default_factory=list)
    denied_tables: list[str] = Field(default_factory=list)
32
|
+
|
|
33
|
+
|
|
34
|
+
class PolicyProfile(BaseModel):
    """A named bundle of operation flags, constraints, and scope rules."""

    profile_name: str
    operations: PolicyOperations = Field(default_factory=PolicyOperations)
    constraints: PolicyConstraints = Field(default_factory=PolicyConstraints)
    scope: PolicyScope = Field(default_factory=PolicyScope)
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class PolicyFile(BaseModel):
    """On-disk policy document: named profiles plus the default selection."""

    active_profile: str = "readonly"  # used when no runtime override is given
    profiles: dict[str, PolicyProfile]
|
|
File without changes
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Any, Awaitable, Callable, Optional
|
|
4
|
+
|
|
5
|
+
from mcp.server import Server
|
|
6
|
+
from mcp.types import Tool, TextContent
|
|
7
|
+
|
|
8
|
+
from ..config import Config, get_config
|
|
9
|
+
from ..services import metrics_service
|
|
10
|
+
from ..services.policy_service import PolicyService
|
|
11
|
+
from ..services.query_service import QueryService
|
|
12
|
+
from ..workflows.facade import WorkflowFacade
|
|
13
|
+
from .compact import compact
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Tool names whose responses carry workflow metadata (kind/detail/etc.)
# that _infer_workflow_metrics extracts for the metrics service.
_WORKFLOW_TOOLS = frozenset({
    "plan_or_execute_query",
    "preview_safe_query",
    "discover_relevant_tables",
    "suggest_next_tool",
    "estimate_execution_risk",
    "bundle_context_for_next_step",
    "score_join_candidate",
    "summarize_table_for_joining",
    "summarize_object_for_impact",
})
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Context:
    """Service container shared by all tool handlers for one process."""

    def __init__(self, cfg: Config) -> None:
        self.cfg = cfg
        # Policy is constructed first; the query service takes it as a dependency.
        self.policy = PolicyService(cfg)
        self.query = QueryService(self.policy, cfg)
        self.workflow = WorkflowFacade(cfg, self.policy, self.query)
|
37
|
+
|
|
38
|
+
|
|
39
|
+
_ctx: Context | None = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_context() -> Context:
    """Return the process-wide :class:`Context`, building it lazily.

    The context is published to ``_ctx`` only after ``policy.load()``
    succeeds, so a failed load propagates and is retried on the next call
    instead of leaving a half-initialized singleton behind.
    """
    global _ctx
    if _ctx is None:
        ctx = Context(get_config())
        ctx.policy.load()
        _ctx = ctx  # publish only after successful initialization
    return _ctx
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def reset_context() -> None:
    """Drop the cached context so the next get_context() call rebuilds it."""
    global _ctx
    _ctx = None
|
53
|
+
|
|
54
|
+
|
|
55
|
+
app = Server("sqlserver-semantic-mcp")
|
|
56
|
+
|
|
57
|
+
ToolHandler = Callable[[dict[str, Any]], Awaitable[Any]]
|
|
58
|
+
_TOOL_REGISTRY: dict[str, tuple[Tool, ToolHandler]] = {}
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def register_tool(tool: Tool, handler: ToolHandler) -> None:
    """Add *tool* and its handler to the registry.

    Raises:
        ValueError: if a tool with the same name was already registered —
            duplicate registration is a programming error, not a runtime
            condition.
    """
    key = tool.name
    if key in _TOOL_REGISTRY:
        raise ValueError(f"Duplicate tool registration: {key}")
    _TOOL_REGISTRY[key] = (tool, handler)
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@app.list_tools()
async def _list_tools() -> list[Tool]:
    """MCP list_tools hook: advertise every registered tool definition."""
    return [t for (t, _) in _TOOL_REGISTRY.values()]
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _infer_workflow_metrics(name: str, shaped: Any) -> dict[str, Any]:
    """Extract workflow-aware fields from the response envelope, if any.

    Args:
        name: Name of the tool that produced the response.
        shaped: Compacted response payload (any JSON-serializable value).

    Returns:
        Metric extras keyed by field name; empty when *shaped* is not a dict.
    """
    extras: dict[str, Any] = {}
    if not isinstance(shaped, dict):
        return extras
    if name in _WORKFLOW_TOOLS:
        # Only record route_type when "kind" is actually present: storing
        # None would emit a useless metric field AND block the
        # data["path"] fallback below.
        kind = shaped.get("kind")
        if kind is not None:
            extras["route_type"] = kind
        for key in ("detail", "response_mode", "token_budget_hint",
                    "next_action", "bundle_key"):
            if key in shaped:
                extras[key] = shaped.get(key)
    data = shaped.get("data")
    if isinstance(data, dict):
        if "path" in data and extras.get("route_type") is None:
            extras["route_type"] = data.get("path")
        if name == "plan_or_execute_query":
            extras["was_direct_execute"] = (
                data.get("path") == "direct_execute"
                and bool(data.get("executed"))
            )
        if "response_mode" not in extras and "response_mode" in data:
            extras["response_mode"] = data.get("response_mode")
    if "bundle_key" in shaped:
        extras["bundle_used"] = True
    return extras
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@app.call_tool()
async def _call_tool(name: str, arguments: dict) -> list[TextContent]:
    """MCP call_tool hook: dispatch to the registered handler and shape the reply.

    The handler result is compacted, JSON-encoded, and (when metrics are
    enabled) measured. Any exception from the handler is converted into an
    ``{"error": ...}`` payload rather than propagated to the transport.
    """
    if name not in _TOOL_REGISTRY:
        return [TextContent(type="text", text=f"Unknown tool: {name}")]
    _t, handler = _TOOL_REGISTRY[name]
    try:
        result = await handler(arguments or {})
        shaped = compact(result)
        # Compact separators keep the wire payload small.
        text = json.dumps(
            shaped, ensure_ascii=False, default=str, separators=(",", ":"),
        )
        if get_config().metrics_enabled:
            # Metrics are best-effort; a recording failure must never fail
            # the tool call itself.
            try:
                extras = _infer_workflow_metrics(name, shaped)
                await metrics_service.record_metric(
                    get_config().cache_path, name,
                    response_bytes=len(text.encode("utf-8")),
                    array_length=len(shaped) if isinstance(shaped, list) else None,
                    fields_returned=len(shaped) if isinstance(shaped, dict) else None,
                    **extras,
                )
            except Exception:
                logger.exception("metrics_service.record_metric failed")
        return [TextContent(type="text", text=text)]
    except Exception as e:
        logger.exception("Tool %s raised", name)
        return [TextContent(type="text", text=json.dumps({"error": str(e)}))]
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Transport-layer response shaping helper.
|
|
2
|
+
|
|
3
|
+
See docs/superpowers/specs/2026-04-19-p0-token-optimization-design.md for rules.
|
|
4
|
+
"""
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
# Keys whose explicit False value compact() preserves instead of stripping.
NULLABLE_FALSE_KEEP: frozenset[str] = frozenset({"is_nullable"})
+
|
|
9
|
+
|
|
10
|
+
def _is_falsy_strippable(value: Any) -> bool:
|
|
11
|
+
return value is None or value == [] or value == {} or value is False
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _merge_table_id(d: dict) -> dict:
|
|
15
|
+
schema = d.get("schema_name")
|
|
16
|
+
table = d.get("table_name")
|
|
17
|
+
if not (isinstance(schema, str) and isinstance(table, str) and schema and table):
|
|
18
|
+
return d
|
|
19
|
+
out: dict[str, Any] = {}
|
|
20
|
+
merged = False
|
|
21
|
+
for k, v in d.items():
|
|
22
|
+
if k == "schema_name":
|
|
23
|
+
if not merged:
|
|
24
|
+
out["table"] = f"{schema}.{table}"
|
|
25
|
+
merged = True
|
|
26
|
+
elif k == "table_name":
|
|
27
|
+
continue
|
|
28
|
+
else:
|
|
29
|
+
out[k] = v
|
|
30
|
+
return out
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _merge_object_id(d: dict) -> dict:
|
|
34
|
+
schema = d.get("schema")
|
|
35
|
+
name = d.get("object_name")
|
|
36
|
+
if not (isinstance(schema, str) and isinstance(name, str) and schema and name):
|
|
37
|
+
return d
|
|
38
|
+
out: dict[str, Any] = {}
|
|
39
|
+
merged = False
|
|
40
|
+
for k, v in d.items():
|
|
41
|
+
if k == "schema":
|
|
42
|
+
if not merged:
|
|
43
|
+
out["object"] = f"{schema}.{name}"
|
|
44
|
+
merged = True
|
|
45
|
+
elif k == "object_name":
|
|
46
|
+
continue
|
|
47
|
+
elif k == "object_type":
|
|
48
|
+
out["type"] = v
|
|
49
|
+
else:
|
|
50
|
+
out[k] = v
|
|
51
|
+
return out
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def compact(obj: Any) -> Any:
    """Recursively strip falsy values and merge identifier pairs.

    Application order within a dict: R2 (table merge) -> R3 (object merge) -> R1 (strip).
    Lists are compacted element-wise; scalars pass through unchanged.
    """
    if isinstance(obj, list):
        return [compact(item) for item in obj]
    if not isinstance(obj, dict):
        return obj
    normalized = _merge_object_id(_merge_table_id(obj))
    result: dict[str, Any] = {}
    for key, value in normalized.items():
        value = compact(value)
        # Explicit False survives for whitelisted keys (e.g. is_nullable).
        keep_false = key in NULLABLE_FALSE_KEEP and value is False
        if keep_false or not _is_falsy_strippable(value):
            result[key] = value
    return result
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Object / impact analysis prompts."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from mcp.types import (
|
|
5
|
+
GetPromptResult, Prompt, PromptArgument, PromptMessage, TextContent,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
from .registry import register_prompt
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Prompt metadata advertised to MCP clients for the impact-analysis flow.
_PROMPT = Prompt(
    name="trace_data_impact",
    description=(
        "Trace the downstream impact of changing a view/procedure/function "
        "without dumping raw SQL bodies into the context."
    ),
    arguments=[
        PromptArgument(name="schema", required=True),
        PromptArgument(name="name", required=True),
        PromptArgument(
            name="type",
            description="VIEW | PROCEDURE | FUNCTION",
            required=True,
        ),
    ],
)
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Template rendered by _handler; placeholders are filled via str.format,
# so {x!r} conversions apply at format time.
_BODY = """You need to understand the impact of modifying {type} {schema}.{name}. Follow the impact chain:

1. `summarize_object_for_impact(schema={schema!r}, name={name!r}, type={type!r})` — returns reads / writes / depends_on in compact form.
2. `trace_object_dependencies(schema={schema!r}, name={name!r}, type={type!r})` — returns the dependency list.
3. `bundle_context_for_next_step(items=[...], goal="object_impact")` — compress before recommending changes.

Only request full definitions (`describe_view` / `describe_procedure` with detail="full") if the summaries leave a concrete gap.
"""
|
37
|
+
|
|
38
|
+
|
|
39
|
+
async def _handler(arguments: dict) -> GetPromptResult:
    """Render the impact-analysis prompt for the requested schema object.

    Missing "schema"/"name" arguments default to empty strings; a missing
    or falsy "type" defaults to VIEW, and the type is upper-cased.
    """
    schema = arguments.get("schema", "")
    name = arguments.get("name", "")
    obj_type = (arguments.get("type") or "VIEW").upper()
    message = PromptMessage(
        role="user",
        content=TextContent(
            type="text",
            text=_BODY.format(schema=schema, name=name, type=obj_type),
        ),
    )
    return GetPromptResult(
        description="Impact analysis chain for schema objects.",
        messages=[message],
    )
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def register() -> None:
    """Add this module's prompt and handler to the shared prompt registry."""
    register_prompt(_PROMPT, _handler)
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""Discovery prompts."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from mcp.types import (
|
|
5
|
+
GetPromptResult, Prompt, PromptArgument, PromptMessage, TextContent,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
from .registry import register_prompt
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Prompt metadata advertised to MCP clients for table discovery.
_PROMPT = Prompt(
    name="discover_tables_for_business_question",
    description=(
        "Translate a natural-language question into the shortest discovery "
        "chain — candidates → describe → join path."
    ),
    arguments=[
        PromptArgument(
            name="goal",
            description="Free-form business question.",
            required=True,
        ),
    ],
)
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Template rendered by _handler; {goal} is filled via str.format.
_BODY = """You have a business question but not the target tables. Follow the discovery chain:

1. `discover_relevant_tables(goal={goal!r})` — returns a small ranked candidate set.
2. For the top 2–3 candidates, call `describe_table(detail="brief")` only. Skip "full" until you must.
3. If the question joins concepts, call `find_join_path` for each plausible pair, then `score_join_candidate` to pick the best.
4. When you are confident, draft SQL and call `plan_or_execute_query` with mode="auto".

Keep each step's detail level at "brief" unless the prior step surfaced ambiguity.

Question: {goal}
"""
|
38
|
+
|
|
39
|
+
|
|
40
|
+
async def _handler(arguments: dict) -> GetPromptResult:
    """Render the discovery-chain prompt for the supplied business question.

    A missing "goal" argument defaults to an empty string.
    """
    goal = arguments.get("goal", "")
    message = PromptMessage(
        role="user",
        content=TextContent(type="text", text=_BODY.format(goal=goal)),
    )
    return GetPromptResult(
        description="Discovery chain for unknown tables.",
        messages=[message],
    )
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def register() -> None:
    """Add this module's prompt and handler to the shared prompt registry."""
    register_prompt(_PROMPT, _handler)
|