sql-code-graph 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_code_graph-0.2.1.dist-info/METADATA +171 -0
- sql_code_graph-0.2.1.dist-info/RECORD +55 -0
- sql_code_graph-0.2.1.dist-info/WHEEL +4 -0
- sql_code_graph-0.2.1.dist-info/entry_points.txt +2 -0
- sqlcg/__init__.py +5 -0
- sqlcg/__main__.py +6 -0
- sqlcg/cli/__init__.py +1 -0
- sqlcg/cli/commands/__init__.py +1 -0
- sqlcg/cli/commands/analyze.py +93 -0
- sqlcg/cli/commands/db.py +83 -0
- sqlcg/cli/commands/find.py +63 -0
- sqlcg/cli/commands/gain.py +169 -0
- sqlcg/cli/commands/git.py +73 -0
- sqlcg/cli/commands/index.py +92 -0
- sqlcg/cli/commands/install.py +60 -0
- sqlcg/cli/commands/mcp.py +54 -0
- sqlcg/cli/commands/report.py +135 -0
- sqlcg/cli/commands/watch.py +57 -0
- sqlcg/cli/main.py +40 -0
- sqlcg/core/__init__.py +8 -0
- sqlcg/core/config.py +104 -0
- sqlcg/core/graph_db.py +179 -0
- sqlcg/core/jobs.py +105 -0
- sqlcg/core/kuzu_backend.py +269 -0
- sqlcg/core/neo4j_backend.py +195 -0
- sqlcg/core/queries.py +82 -0
- sqlcg/core/schema.cypher +104 -0
- sqlcg/core/schema.py +48 -0
- sqlcg/indexer/__init__.py +1 -0
- sqlcg/indexer/dbt_adapter.py +23 -0
- sqlcg/indexer/indexer.py +317 -0
- sqlcg/indexer/walker.py +55 -0
- sqlcg/indexer/watcher.py +195 -0
- sqlcg/lineage/__init__.py +1 -0
- sqlcg/lineage/aggregator.py +58 -0
- sqlcg/lineage/schema_resolver.py +198 -0
- sqlcg/metrics/__init__.py +5 -0
- sqlcg/metrics/store.py +273 -0
- sqlcg/parsers/__init__.py +30 -0
- sqlcg/parsers/ansi_parser.py +215 -0
- sqlcg/parsers/base.py +414 -0
- sqlcg/parsers/bigquery_parser.py +77 -0
- sqlcg/parsers/postgres_parser.py +27 -0
- sqlcg/parsers/registry.py +46 -0
- sqlcg/parsers/snowflake_parser.py +148 -0
- sqlcg/parsers/tsql_parser.py +27 -0
- sqlcg/server/__init__.py +1 -0
- sqlcg/server/exceptions.py +20 -0
- sqlcg/server/models.py +83 -0
- sqlcg/server/server.py +57 -0
- sqlcg/server/tools.py +663 -0
- sqlcg/utils/__init__.py +6 -0
- sqlcg/utils/hashing.py +18 -0
- sqlcg/utils/ignore.py +36 -0
- sqlcg/utils/logging.py +29 -0
sqlcg/server/tools.py
ADDED
|
@@ -0,0 +1,663 @@
|
|
|
1
|
+
"""MCP tools for SQL code graph queries and indexing."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
import time
|
|
5
|
+
from collections import deque
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from sqlcg.core.config import get_db_path
|
|
9
|
+
from sqlcg.core.graph_db import GraphBackend
|
|
10
|
+
from sqlcg.core.kuzu_backend import KuzuBackend
|
|
11
|
+
from sqlcg.core.queries import (
|
|
12
|
+
FIND_TABLE_USAGES_QUERY,
|
|
13
|
+
GET_DOWNSTREAM_DEPENDENCIES_QUERY,
|
|
14
|
+
GET_UPSTREAM_DEPENDENCIES_QUERY,
|
|
15
|
+
INDEX_REPO_FILES_QUERY,
|
|
16
|
+
LIST_DIALECTS_AND_REPOS_QUERY,
|
|
17
|
+
SEARCH_SQL_PATTERN_QUERY,
|
|
18
|
+
TRACE_COLUMN_LINEAGE_QUERY,
|
|
19
|
+
)
|
|
20
|
+
from sqlcg.indexer.indexer import Indexer
|
|
21
|
+
from sqlcg.metrics.store import MetricsStore
|
|
22
|
+
from sqlcg.server.exceptions import InvalidColumnRefError, NotIndexedError
|
|
23
|
+
from sqlcg.server.models import (
|
|
24
|
+
DependencyNode,
|
|
25
|
+
DependencyResult,
|
|
26
|
+
DialectRepo,
|
|
27
|
+
DialectRepoResult,
|
|
28
|
+
LineageNode,
|
|
29
|
+
LineageResult,
|
|
30
|
+
SqlPatternMatch,
|
|
31
|
+
SqlPatternResult,
|
|
32
|
+
TableUsage,
|
|
33
|
+
TableUsageResult,
|
|
34
|
+
)
|
|
35
|
+
from sqlcg.server.server import mcp # noqa: F401
|
|
36
|
+
from sqlcg.utils.logging import getLogger
|
|
37
|
+
|
|
38
|
+
# Module logger, created through sqlcg.utils.logging.getLogger.
logger = getLogger(__name__)

# Module-level singleton backend (KùzuDB single-writer model).
# Assigned by init_backend() and reset to None by shutdown_backend(); tools read
# it through _get_backend(), which raises RuntimeError while it is still None.
_backend: GraphBackend | None = None

# Module-level metrics store singleton.
# Assigned by init_backend() and reset to None by shutdown_backend(). Every reader
# checks for None before using it, so metrics recording is best-effort.
_metrics: MetricsStore | None = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def init_backend(db_path: str | None = None) -> None:
    """Initialize the module-level backend and metrics singletons.

    Args:
        db_path: Path to KùzuDB database. If None (or empty), uses get_db_path().

    Raises:
        RuntimeError: If the backend's schema initialization fails; the
            partially opened backend is closed before raising. Exceptions raised
            by the KuzuBackend constructor itself propagate unchanged.

    The metrics store is best-effort: if it cannot be created or its schema
    cannot be initialized, a WARNING is logged, any partially created store is
    closed, and ``_metrics`` is set to None so tools skip metrics recording.
    """
    global _backend, _metrics
    path = db_path or str(get_db_path())
    backend = KuzuBackend(path)
    try:
        backend.init_schema()
    except Exception as exc:
        backend.close()
        raise RuntimeError(f"Backend initialization failed: {exc}") from exc
    _backend = backend
    logger.debug(f"Backend initialized: {path}")

    # Build the store in a local so a failure in init_schema() cannot leave a
    # half-initialized store published in the module global.
    metrics: MetricsStore | None = None
    try:
        metrics_path = Path.home() / ".sqlcg" / "metrics.db"
        metrics = MetricsStore(metrics_path)
        metrics.init_schema()
    except Exception as exc:
        logger.warning(f"Failed to initialize metrics store: {exc}")
        if metrics is not None:
            # Release whatever the constructor opened; a failure here must not
            # prevent the server from starting.
            try:
                metrics.close()
            except Exception as close_exc:
                logger.debug(f"Failed to close metrics store after init failure: {close_exc}")
        metrics = None
    _metrics = metrics
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def shutdown_backend() -> None:
    """Shutdown the module-level backend and metrics singletons.

    Both globals are cleared *before* their ``close()`` calls, so a failing
    close can never leave a dead handle registered. This keeps the function
    safe to call multiple times.

    Raises:
        Whatever the backend's ``close()`` raises. The metrics store is still
        closed in that case. Failures closing the metrics store are logged as
        warnings, matching the best-effort handling of metrics elsewhere in
        this module.
    """
    global _backend, _metrics
    backend, _backend = _backend, None
    metrics, _metrics = _metrics, None
    try:
        if backend is not None:
            backend.close()
            logger.debug("Backend shut down")
    finally:
        if metrics is not None:
            try:
                metrics.close()
            except Exception as exc:
                logger.warning(f"Failed to close metrics store: {exc}")
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _get_backend() -> GraphBackend:
    """Return the graph backend registered by init_backend().

    Raises:
        RuntimeError: While no backend is registered, i.e. before init_backend()
            has run or after shutdown_backend() has cleared it.
    """
    backend = _backend
    if backend is not None:
        return backend
    raise RuntimeError("Backend not initialized. Call init_backend() before using tools.")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _assert_indexed(db: GraphBackend) -> None:
    """Guard for query tools: require at least one Repo node in the graph.

    Args:
        db: Backend to inspect.

    Raises:
        NotIndexedError: When the Repo count query returns no rows or a count of 0.
    """
    result = db.run_read("MATCH (r:Repo) RETURN count(r) AS n", {})
    # An empty result set is treated the same as an explicit zero count.
    repo_count = result[0]["n"] if result else 0
    if repo_count == 0:
        raise NotIndexedError("No repos have been indexed. Run `sqlcg index <path>` first.")
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _parse_column_ref(col_ref: str) -> tuple[str, str]:
    """Split a column reference into ``(table_id, column_name)``.

    Accepts "table.column" or any deeper qualification such as
    "catalog.db.table.column". The last dot-separated segment is the column
    name, and everything before it is the table id. Interior empty segments
    (e.g. "db..table.col") are passed through unchanged as part of the table id.

    Args:
        col_ref: Column reference string

    Returns:
        Tuple of (table_id, column_name)

    Raises:
        InvalidColumnRefError: If there is no "." separator, or the column name
            or the leading table segment is empty/whitespace
            (e.g. "orders", "orders.", ".id").
    """
    table_id, sep, column_name = col_ref.rpartition(".")
    # rpartition yields sep == "" when there is no dot at all. An empty column
    # name or an empty first table segment would otherwise produce an id that
    # silently matches nothing in the graph.
    leading_table_segment = table_id.split(".", 1)[0]
    if not sep or not column_name.strip() or not leading_table_segment.strip():
        raise InvalidColumnRefError(
            f"Invalid column reference: {col_ref} (expected 'table.column' or "
            f"'catalog.db.table.column')"
        )
    return table_id, column_name
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _record_tool_call(tool_name: str, duration_ms: float, success: bool = True) -> None:
    """Best-effort write of one tool invocation to the metrics store.

    Does nothing when no metrics store is registered. Any exception from the
    store is logged as a warning and never propagated to the caller.

    Args:
        tool_name: Name of the tool.
        duration_ms: Execution time in milliseconds.
        success: Whether the call succeeded.
    """
    store = _metrics
    if store is None:
        return
    try:
        store.record_tool_call(tool_name, duration_ms, success)
    except Exception as exc:
        logger.warning(f"Failed to record metrics for {tool_name}: {exc}")
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _timed_tool(tool_name: str):
    """Decorator factory that records execution time and success of a tool.

    The wrapped call's duration (monotonic clock, in milliseconds) is passed to
    _record_tool_call() with success=True when the call returns. When the call
    raises an Exception, it is recorded with success=False and the exception
    is re-raised.

    ``functools.wraps`` matters here. ``@mcp.tool()`` is applied on top of this
    decorator and builds the tool's name, description and argument schema from
    the function object it receives. This is FastMCP behaviour; assumed to be
    the server used here. Without it, every decorated tool would be exposed as
    ``wrapper(*args, **kwargs)`` with no docstring.

    Args:
        tool_name: Name of the tool to record.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                result = func(*args, **kwargs)
            except Exception:
                _record_tool_call(tool_name, (time.perf_counter() - start) * 1000, False)
                raise
            # Recorded outside the try so a recording problem can never be
            # misreported as a tool failure.
            _record_tool_call(tool_name, (time.perf_counter() - start) * 1000, True)
            return result

        return wrapper

    return decorator
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
@mcp.tool()
def index_repo(repo_path: str, dialect: str = "ansi") -> dict:
    """Index a repository of SQL files.

    Parses SQL files, extracts table and column definitions, and builds
    lineage edges. Results are persisted to the graph database. Only
    git-tracked files are indexed when the directory is a git repo —
    untracked files, build artifacts, and node_modules are ignored
    automatically. Falls back to a full directory scan when git is
    unavailable.

    Args:
        repo_path: Root directory path to index
        dialect: SQL dialect (ansi, snowflake, bigquery, postgres, tsql)

    Returns:
        Dict with keys: files_parsed, parse_errors, tables_found, lineage_edges_created
    """
    # Timed by hand (not via _timed_tool) because an index-run metric is
    # recorded in addition to the generic tool-call metric. The finally clause
    # records the tool call on success *and* failure, like every other tool.
    start_time = time.perf_counter()
    success = False
    try:
        db = _get_backend()
        path = Path(repo_path).resolve()
        if not path.exists():
            raise ValueError(f"Repository path does not exist: {repo_path}")
        if not path.is_dir():
            raise ValueError(f"Repository path is not a directory: {repo_path}")

        from sqlcg.core.schema import NodeLabel, RelType

        # Ensure the Repo node exists before File nodes are linked to it.
        abs_path = str(path)
        db.upsert_node(
            NodeLabel.REPO,
            abs_path,
            {
                "path": abs_path,
                "name": path.name,
            },
        )

        # Index the repository (with absolute path)
        result = Indexer().index_repo(path, dialect, db)

        # Create BELONGS_TO relationships from every File node under this repo
        # to the Repo node.
        # NOTE(review): the prefix uses "/" as separator; confirm File.path
        # values use the same separator on Windows.
        repo_prefix = abs_path.rstrip("/") + "/"
        for row in db.run_read(INDEX_REPO_FILES_QUERY, {"repo_prefix": repo_prefix}):
            db.upsert_edge(
                NodeLabel.FILE,
                row["path"],
                NodeLabel.REPO,
                abs_path,
                RelType.BELONGS_TO,
                {},
            )

        logger.info(f"Indexed {result['files_parsed']} files with {result['tables_found']} tables")

        # Index-run metrics are best-effort and never fail the tool.
        if _metrics is not None:
            try:
                _metrics.record_index_run(
                    abs_path,
                    result.get("files_parsed", 0),
                    result.get("parse_errors", 0),
                    result.get("tables_found", 0),
                    result.get("lineage_edges_created", 0),
                    (time.perf_counter() - start_time) * 1000,
                )
            except Exception as exc:
                logger.warning(f"Failed to record index run metrics: {exc}")

        success = True
        return result
    finally:
        _record_tool_call("index_repo", (time.perf_counter() - start_time) * 1000, success)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@mcp.tool()
@_timed_tool("trace_column_lineage")
def trace_column_lineage(table_col: str, max_depth: int = 5) -> LineageResult:
    """Trace upstream lineage of a column.

    Traverses COLUMN_LINEAGE edges backward up to max_depth levels.

    Args:
        table_col: Column reference in format "table.column"
            or "catalog.db.table.column"
        max_depth: Maximum number of hops to traverse

    Returns:
        LineageResult with list of upstream column nodes

    Raises:
        NotIndexedError: If no repos have been indexed
        InvalidColumnRefError: If column reference format is invalid
    """
    db = _get_backend()
    _assert_indexed(db)

    # Validates the reference (raises InvalidColumnRefError on bad input).
    table_id, col_name = _parse_column_ref(table_col)
    col_id = f"{table_id}.{col_name}"

    lineage: list[LineageNode] = []
    # Nodes are marked as seen when *discovered*, not when expanded. A column
    # reachable through several paths is therefore reported once, and the
    # starting column itself is never reported.
    seen: set[str] = {col_id}
    frontier: deque[tuple[str, int]] = deque([(col_id, 0)])

    while frontier:
        current_id, depth = frontier.popleft()
        # Nodes at depth == max_depth have already been reported; expanding
        # them would report columns max_depth + 1 hops away.
        if depth >= max_depth:
            continue

        # Query for upstream columns (reverse direction)
        for row in db.run_read(TRACE_COLUMN_LINEAGE_QUERY, {"id": current_id}):
            node_id = row["id"]
            if node_id in seen:
                continue
            seen.add(node_id)
            lineage.append(
                LineageNode(
                    name=row.get("col_name", ""),
                    kind="column",
                    file=None,
                    confidence=None,
                )
            )
            frontier.append((node_id, depth + 1))

    return LineageResult(column=table_col, lineage=lineage)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
@mcp.tool()
@_timed_tool("find_table_usages")
def find_table_usages(table_name: str) -> TableUsageResult:
    """Find all queries that use a given table.

    Searches for SELECTS_FROM relationships pointing to the table.

    Args:
        table_name: Table name to search for

    Returns:
        TableUsageResult with list of queries using this table

    Raises:
        NotIndexedError: If no repos have been indexed
    """
    backend = _get_backend()
    _assert_indexed(backend)

    # Each row must carry "file"; "sql" and "kind" fall back to None when absent.
    usages = [
        TableUsage(query_file=r["file"], sql=r.get("sql"), kind=r.get("kind"))
        for r in backend.run_read(FIND_TABLE_USAGES_QUERY, {"name": table_name})
    ]
    return TableUsageResult(table=table_name, usages=usages)
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
@mcp.tool()
@_timed_tool("get_downstream_dependencies")
def get_downstream_dependencies(table_col: str, max_depth: int = 5) -> DependencyResult:
    """Find all downstream dependencies of a column.

    Traverses COLUMN_LINEAGE edges forward, up to max_depth hops, to find
    columns that depend on this one.

    Args:
        table_col: Column reference in format "table.column"
            or "catalog.db.table.column"
        max_depth: Maximum number of hops to traverse

    Returns:
        DependencyResult with list of downstream column nodes

    Raises:
        NotIndexedError: If no repos have been indexed
        InvalidColumnRefError: If column reference format is invalid
    """
    db = _get_backend()
    _assert_indexed(db)

    # Validates the reference (raises InvalidColumnRefError on bad input).
    table_id, col_name = _parse_column_ref(table_col)
    col_id = f"{table_id}.{col_name}"

    nodes: list[DependencyNode] = []
    # Mark on discovery so diamond-shaped graphs report each column once and
    # the starting column is never included.
    seen: set[str] = {col_id}
    frontier: deque[tuple[str, int]] = deque([(col_id, 0)])

    while frontier:
        current_id, depth = frontier.popleft()
        # Expanding a node already at max_depth would report max_depth + 1 hops.
        if depth >= max_depth:
            continue

        # Query for downstream columns (forward direction)
        for row in db.run_read(GET_DOWNSTREAM_DEPENDENCIES_QUERY, {"id": current_id}):
            node_id = row["id"]
            if node_id in seen:
                continue
            seen.add(node_id)
            nodes.append(
                DependencyNode(
                    name=row.get("col_name", ""),
                    kind="column",
                )
            )
            frontier.append((node_id, depth + 1))

    return DependencyResult(root=table_col, nodes=nodes)
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
@mcp.tool()
@_timed_tool("get_upstream_dependencies")
def get_upstream_dependencies(table_col: str, max_depth: int = 5) -> DependencyResult:
    """Find all upstream dependencies of a column.

    Traverses COLUMN_LINEAGE edges backward, up to max_depth hops, to find
    columns this one depends on.

    Args:
        table_col: Column reference in format "table.column"
            or "catalog.db.table.column"
        max_depth: Maximum number of hops to traverse

    Returns:
        DependencyResult with list of upstream column nodes

    Raises:
        NotIndexedError: If no repos have been indexed
        InvalidColumnRefError: If column reference format is invalid
    """
    db = _get_backend()
    _assert_indexed(db)

    # Validates the reference (raises InvalidColumnRefError on bad input).
    table_id, col_name = _parse_column_ref(table_col)
    col_id = f"{table_id}.{col_name}"

    nodes: list[DependencyNode] = []
    # Mark on discovery so diamond-shaped graphs report each column once and
    # the starting column is never included.
    seen: set[str] = {col_id}
    frontier: deque[tuple[str, int]] = deque([(col_id, 0)])

    while frontier:
        current_id, depth = frontier.popleft()
        # Expanding a node already at max_depth would report max_depth + 1 hops.
        if depth >= max_depth:
            continue

        # Query for upstream columns (reverse direction)
        for row in db.run_read(GET_UPSTREAM_DEPENDENCIES_QUERY, {"id": current_id}):
            node_id = row["id"]
            if node_id in seen:
                continue
            seen.add(node_id)
            nodes.append(
                DependencyNode(
                    name=row.get("col_name", ""),
                    kind="column",
                )
            )
            frontier.append((node_id, depth + 1))

    return DependencyResult(root=table_col, nodes=nodes)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
@mcp.tool()
@_timed_tool("search_sql_pattern")
def search_sql_pattern(query: str, limit: int = 20) -> SqlPatternResult:
    """Search for SQL patterns in indexed queries.

    Uses substring matching on the query SQL text.

    Args:
        query: Pattern string to search for
        limit: Maximum number of results (default 20)

    Returns:
        SqlPatternResult with list of matching queries

    Raises:
        NotIndexedError: If no repos have been indexed
    """
    backend = _get_backend()
    _assert_indexed(backend)

    # Both the pattern and the row cap are passed as query parameters.
    params = {"query": query, "limit": limit}
    # "file" is required per row; a missing "sql" becomes "", a missing "kind" None.
    matches = [
        SqlPatternMatch(file=r["file"], sql=r.get("sql", ""), kind=r.get("kind"))
        for r in backend.run_read(SEARCH_SQL_PATTERN_QUERY, params)
    ]
    return SqlPatternResult(pattern=query, matches=matches)
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
@mcp.tool()
@_timed_tool("list_dialects_and_repos")
def list_dialects_and_repos() -> DialectRepoResult:
    """List all indexed repositories and their SQL dialects.

    Returns:
        DialectRepoResult with list of repositories and their dialects

    Raises:
        NotIndexedError: If no repos have been indexed
    """
    backend = _get_backend()
    _assert_indexed(backend)

    rows = backend.run_read(LIST_DIALECTS_AND_REPOS_QUERY, {})
    # "path" is required; a missing "name" becomes None and missing "dialects" [].
    return DialectRepoResult(
        repos=[
            DialectRepo(path=r["path"], name=r.get("name"), dialects=r.get("dialects", []))
            for r in rows
        ]
    )
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
# Spans of Cypher text that can never be executable code:
#   - string literals (backslash escapes, as in Cypher),
#   - backtick-quoted identifiers (a doubled backtick escapes a backtick),
#   - // and /* */ comments.
# They are matched in ONE left-to-right pass. Otherwise a quote inside a
# comment, an identifier or the other kind of string could open a bogus
# "string" that hides real code from the write-keyword check. For example,
# `// it's\nCREATE ... 'x'` or `'a\'' ... CREATE ... 'b'` defeated the
# previous two-pass, ''-escape stripping.
_CYPHER_NON_CODE_RE = re.compile(
    r"'(?:\\.|[^'\\])*'"
    r'|"(?:\\.|[^"\\])*"'
    r"|`(?:``|[^`])*`"
    r"|//[^\n]*"
    r"|/\*.*?\*/",
    re.DOTALL,
)

# Keywords that begin write or side-effecting statements. This is the
# original list plus bulk-load/export, schema and extension/database
# management statements (e.g. Kùzu's COPY, ALTER, LOAD, INSTALL, ATTACH,
# IMPORT/EXPORT DATABASE).
_CYPHER_WRITE_RE = re.compile(
    r"\b(?:CREATE|MERGE|DELETE|DETACH|SET|REMOVE|DROP|TRUNCATE"
    r"|ALTER|COPY|LOAD|INSTALL|ATTACH|IMPORT|EXPORT)\b",
    re.IGNORECASE,
)

# Whole-word LIMIT, so identifiers such as `rate_limit` do not count.
_CYPHER_LIMIT_RE = re.compile(r"\bLIMIT\b", re.IGNORECASE)

# Row cap appended to queries that do not specify their own LIMIT.
_CYPHER_DEFAULT_LIMIT = 500


@mcp.tool()
@_timed_tool("execute_cypher")
def execute_cypher(query: str) -> list[dict]:
    """Execute a read-only Cypher query against the graph.

    This tool allows direct Cypher queries for advanced users. It enforces
    read-only mode by removing string literals, backtick-quoted identifiers
    and comments, then checking the remaining text for write operation
    keywords. A LIMIT 500 clause is automatically appended if the query has
    no LIMIT of its own.

    **Important Security Note**: String literals, quoted identifiers and
    comments are ignored by the write-operation check, so text such as
    'DROP TABLE' inside a string will NOT trigger the blocker. This is by
    design to allow querying SQL text that contains such keywords. The check
    is a keyword blocklist, not a full parser; procedures invoked via CALL
    are not inspected.

    Args:
        query: Cypher query string (read-only)

    Returns:
        List of result dictionaries from the query

    Raises:
        ValueError: If the query contains write operations (CREATE, MERGE,
            DELETE, DETACH, SET, REMOVE, DROP, TRUNCATE, ALTER, COPY, LOAD,
            INSTALL, ATTACH, IMPORT, EXPORT)
    """
    # Replace (rather than delete) non-code spans with a space. Removing e.g.
    # a /**/ comment must not glue neighbouring tokens into one word and hide
    # a keyword from the \b word-boundary check.
    code_only = _CYPHER_NON_CODE_RE.sub(" ", query)

    if _CYPHER_WRITE_RE.search(code_only):
        raise ValueError(
            "Write operations are not permitted via execute_cypher. "
            "Use the CLI or dedicated tools instead."
        )

    q = query.rstrip()
    if q.endswith(";"):
        q = q[:-1].rstrip()
    # Only a LIMIT in actual code counts; "limit" inside a string, comment or
    # longer identifier must not suppress the safety cap.
    if not _CYPHER_LIMIT_RE.search(code_only):
        # Newline rather than space so a trailing // comment cannot swallow it.
        q = f"{q}\nLIMIT {_CYPHER_DEFAULT_LIMIT}"

    db = _get_backend()
    try:
        return db.run_read(q, {})
    except Exception as e:
        logger.error(f"Cypher execution failed: {e}")
        raise
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
@mcp.tool()
def submit_feedback(
    tool_name: str,
    query: str,
    label: str,
    note: str = "",
) -> dict:
    """Submit feedback on a tool result.

    This tool allows users to correct the MCP server's results. Feedback
    is collected and analyzed to identify patterns and false positives.

    **For Claude**: When a user says "that result was wrong" or "this is a
    false positive", call this tool with label="FP". When they confirm
    "that's correct", call with label="TP". Use the query or pattern as
    the 'query' argument and include any user feedback in the 'note'.

    Args:
        tool_name: Name of the tool being evaluated (e.g., "trace_column_lineage")
        query: The query or pattern that was evaluated
        label: Feedback label: "TP" (true positive) or "FP" (false positive);
            surrounding whitespace and letter case are ignored
        note: Optional user note (truncated to 500 chars)

    Returns:
        Dict with status: "recorded" or "skipped"

    Raises:
        ValueError: If label is not "TP" or "FP"
    """
    normalized_label = label.strip().upper()
    if normalized_label not in ("TP", "FP"):
        raise ValueError(f"Invalid label: {label}. Must be 'TP' or 'FP'.")

    store = _metrics
    if store is None:
        return {"status": "skipped"}

    # Enforce the documented 500-character cap here rather than relying on the
    # store. This is a no-op if the store also truncates.
    truncated_note = (note or "")[:500]
    try:
        store.record_feedback(tool_name, query, normalized_label, truncated_note)
    except Exception as exc:
        logger.warning(f"Failed to record feedback: {exc}")
        return {"status": "skipped"}
    return {"status": "recorded"}
|
sqlcg/utils/__init__.py
ADDED
sqlcg/utils/hashing.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""SQL hashing utilities."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def hash_sql(sql: str) -> str:
    """Return the SHA-256 hex digest of *sql*, ignoring surrounding whitespace.

    Only leading and trailing whitespace is normalized away. Interior
    whitespace, case and comments all affect the hash. The text is encoded as
    UTF-8 before hashing.

    Args:
        sql: SQL statement string

    Returns:
        64-character lowercase hexadecimal SHA-256 digest
    """
    hasher = hashlib.sha256()
    hasher.update(sql.strip().encode("utf-8"))
    return hasher.hexdigest()
|