sql-code-graph 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. sql_code_graph-0.2.1.dist-info/METADATA +171 -0
  2. sql_code_graph-0.2.1.dist-info/RECORD +55 -0
  3. sql_code_graph-0.2.1.dist-info/WHEEL +4 -0
  4. sql_code_graph-0.2.1.dist-info/entry_points.txt +2 -0
  5. sqlcg/__init__.py +5 -0
  6. sqlcg/__main__.py +6 -0
  7. sqlcg/cli/__init__.py +1 -0
  8. sqlcg/cli/commands/__init__.py +1 -0
  9. sqlcg/cli/commands/analyze.py +93 -0
  10. sqlcg/cli/commands/db.py +83 -0
  11. sqlcg/cli/commands/find.py +63 -0
  12. sqlcg/cli/commands/gain.py +169 -0
  13. sqlcg/cli/commands/git.py +73 -0
  14. sqlcg/cli/commands/index.py +92 -0
  15. sqlcg/cli/commands/install.py +60 -0
  16. sqlcg/cli/commands/mcp.py +54 -0
  17. sqlcg/cli/commands/report.py +135 -0
  18. sqlcg/cli/commands/watch.py +57 -0
  19. sqlcg/cli/main.py +40 -0
  20. sqlcg/core/__init__.py +8 -0
  21. sqlcg/core/config.py +104 -0
  22. sqlcg/core/graph_db.py +179 -0
  23. sqlcg/core/jobs.py +105 -0
  24. sqlcg/core/kuzu_backend.py +269 -0
  25. sqlcg/core/neo4j_backend.py +195 -0
  26. sqlcg/core/queries.py +82 -0
  27. sqlcg/core/schema.cypher +104 -0
  28. sqlcg/core/schema.py +48 -0
  29. sqlcg/indexer/__init__.py +1 -0
  30. sqlcg/indexer/dbt_adapter.py +23 -0
  31. sqlcg/indexer/indexer.py +317 -0
  32. sqlcg/indexer/walker.py +55 -0
  33. sqlcg/indexer/watcher.py +195 -0
  34. sqlcg/lineage/__init__.py +1 -0
  35. sqlcg/lineage/aggregator.py +58 -0
  36. sqlcg/lineage/schema_resolver.py +198 -0
  37. sqlcg/metrics/__init__.py +5 -0
  38. sqlcg/metrics/store.py +273 -0
  39. sqlcg/parsers/__init__.py +30 -0
  40. sqlcg/parsers/ansi_parser.py +215 -0
  41. sqlcg/parsers/base.py +414 -0
  42. sqlcg/parsers/bigquery_parser.py +77 -0
  43. sqlcg/parsers/postgres_parser.py +27 -0
  44. sqlcg/parsers/registry.py +46 -0
  45. sqlcg/parsers/snowflake_parser.py +148 -0
  46. sqlcg/parsers/tsql_parser.py +27 -0
  47. sqlcg/server/__init__.py +1 -0
  48. sqlcg/server/exceptions.py +20 -0
  49. sqlcg/server/models.py +83 -0
  50. sqlcg/server/server.py +57 -0
  51. sqlcg/server/tools.py +663 -0
  52. sqlcg/utils/__init__.py +6 -0
  53. sqlcg/utils/hashing.py +18 -0
  54. sqlcg/utils/ignore.py +36 -0
  55. sqlcg/utils/logging.py +29 -0
sqlcg/server/tools.py ADDED
@@ -0,0 +1,663 @@
1
+ """MCP tools for SQL code graph queries and indexing."""
2
+
3
+ import re
4
+ import time
5
+ from collections import deque
6
+ from pathlib import Path
7
+
8
+ from sqlcg.core.config import get_db_path
9
+ from sqlcg.core.graph_db import GraphBackend
10
+ from sqlcg.core.kuzu_backend import KuzuBackend
11
+ from sqlcg.core.queries import (
12
+ FIND_TABLE_USAGES_QUERY,
13
+ GET_DOWNSTREAM_DEPENDENCIES_QUERY,
14
+ GET_UPSTREAM_DEPENDENCIES_QUERY,
15
+ INDEX_REPO_FILES_QUERY,
16
+ LIST_DIALECTS_AND_REPOS_QUERY,
17
+ SEARCH_SQL_PATTERN_QUERY,
18
+ TRACE_COLUMN_LINEAGE_QUERY,
19
+ )
20
+ from sqlcg.indexer.indexer import Indexer
21
+ from sqlcg.metrics.store import MetricsStore
22
+ from sqlcg.server.exceptions import InvalidColumnRefError, NotIndexedError
23
+ from sqlcg.server.models import (
24
+ DependencyNode,
25
+ DependencyResult,
26
+ DialectRepo,
27
+ DialectRepoResult,
28
+ LineageNode,
29
+ LineageResult,
30
+ SqlPatternMatch,
31
+ SqlPatternResult,
32
+ TableUsage,
33
+ TableUsageResult,
34
+ )
35
+ from sqlcg.server.server import mcp # noqa: F401
36
+ from sqlcg.utils.logging import getLogger
37
+
38
logger = getLogger(__name__)

# Module-level singleton backend (KùzuDB single-writer model).
# Assigned by init_backend(), cleared by shutdown_backend(); tools fetch it via
# _get_backend(), which raises RuntimeError while it is still None.
_backend: GraphBackend | None = None

# Module-level metrics store singleton. Optional: every consumer checks for
# None and simply skips metrics recording when no store is available.
_metrics: MetricsStore | None = None
45
+
46
+
47
+ def init_backend(db_path: str | None = None) -> None:
48
+ """Initialize the module-level backend singleton.
49
+
50
+ Args:
51
+ db_path: Path to KùzuDB database. If None, uses get_db_path().
52
+
53
+ Raises:
54
+ RuntimeError: If backend initialization fails
55
+ """
56
+ global _backend, _metrics
57
+ path = db_path or str(get_db_path())
58
+ backend = KuzuBackend(path)
59
+ try:
60
+ backend.init_schema()
61
+ except Exception as exc:
62
+ backend.close()
63
+ raise RuntimeError(f"Backend initialization failed: {exc}") from exc
64
+ _backend = backend
65
+ logger.debug(f"Backend initialized: {path}")
66
+
67
+ # Initialize metrics store (best-effort, failures are logged as WARNING)
68
+ try:
69
+ metrics_path = Path.home() / ".sqlcg" / "metrics.db"
70
+ _metrics = MetricsStore(metrics_path)
71
+ _metrics.init_schema()
72
+ except Exception as exc:
73
+ logger.warning(f"Failed to initialize metrics store: {exc}")
74
+
75
+
76
def shutdown_backend() -> None:
    """Shutdown the module-level backend and metrics store singletons.

    Both global references are cleared before anything is closed, and the
    metrics store is closed even when closing the backend raises (that
    exception still propagates to the caller). Safe to call multiple times.
    """
    global _backend, _metrics
    backend, _backend = _backend, None
    metrics, _metrics = _metrics, None
    try:
        if backend is not None:
            backend.close()
            logger.debug("Backend shut down")
    finally:
        # Runs even if backend.close() raised, so the metrics DB is not leaked.
        if metrics is not None:
            metrics.close()
90
+
91
+
92
def _get_backend() -> GraphBackend:
    """Return the process-wide graph backend set up by init_backend().

    Raises:
        RuntimeError: While no backend is available, i.e. init_backend() has
            not been called yet or shutdown_backend() has cleared it.
    """
    backend = _backend
    if backend is not None:
        return backend
    raise RuntimeError("Backend not initialized. Call init_backend() before using tools.")
101
+
102
+
103
def _assert_indexed(db: GraphBackend) -> None:
    """Guard that fails fast when the graph contains no Repo nodes.

    Args:
        db: GraphBackend instance to query.

    Raises:
        NotIndexedError: If the Repo count is zero or the count query
            returned no rows.
    """
    result = db.run_read("MATCH (r:Repo) RETURN count(r) AS n", {})
    repo_count = result[0]["n"] if result else 0
    if repo_count == 0:
        raise NotIndexedError("No repos have been indexed. Run `sqlcg index <path>` first.")
115
+
116
+
117
def _parse_column_ref(col_ref: str) -> tuple[str, str]:
    """Parse column reference "table.column" or "catalog.db.table.column".

    The text after the last dot is the column name; everything before it is
    the table id. Every dot-separated segment must be non-empty, so inputs
    such as "orders.", ".amount" or "db..orders.amount" are rejected instead
    of producing an empty table id or column name.

    Args:
        col_ref: Column reference string

    Returns:
        Tuple of (table_id, column_name)

    Raises:
        InvalidColumnRefError: If format is invalid
    """
    table_id, sep, column_name = col_ref.rpartition(".")
    # No dot at all, an empty column name, or any empty table-id segment.
    if not sep or not column_name or not all(table_id.split(".")):
        raise InvalidColumnRefError(
            f"Invalid column reference: {col_ref} (expected 'table.column' or "
            f"'catalog.db.table.column')"
        )
    return table_id, column_name
139
+
140
+
141
def _record_tool_call(tool_name: str, duration_ms: float, success: bool = True) -> None:
    """Forward one tool invocation to the metrics store, if there is one.

    Best-effort: nothing happens when no metrics store is configured, and any
    exception raised by the store is logged as a warning and suppressed.

    Args:
        tool_name: Name of the tool.
        duration_ms: Execution time in milliseconds.
        success: Whether the call succeeded.
    """
    store = _metrics
    if store is None:
        return
    try:
        store.record_tool_call(tool_name, duration_ms, success)
    except Exception as exc:
        logger.warning(f"Failed to record metrics for {tool_name}: {exc}")
155
+
156
+
157
def _timed_tool(tool_name: str):
    """Decorator factory recording each call's duration and success to metrics.

    The wrapper is built with ``functools.wraps`` so it keeps the wrapped
    function's ``__name__``, ``__doc__`` and (via ``__wrapped__``) its
    signature. This matters because ``@mcp.tool()`` is applied on top of this
    decorator: without it every timed tool would present itself as
    ``wrapper(*args, **kwargs)`` with no docstring.

    Durations use ``time.perf_counter()``, which is monotonic and therefore
    unaffected by system clock adjustments.

    Args:
        tool_name: Name of the tool to record.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                result = func(*args, **kwargs)
            except Exception:
                _record_tool_call(tool_name, (time.perf_counter() - start) * 1000, False)
                raise
            _record_tool_call(tool_name, (time.perf_counter() - start) * 1000, True)
            return result

        return wrapper

    return decorator
180
+
181
+
182
@mcp.tool()
def index_repo(repo_path: str, dialect: str = "ansi") -> dict:
    """Index a repository of SQL files.

    Parses SQL files, extracts table and column definitions, and builds
    lineage edges. Results are persisted to the graph database. Only
    git-tracked files are indexed when the directory is a git repo —
    untracked files, build artifacts, and node_modules are ignored
    automatically. Falls back to a full directory scan when git is
    unavailable.

    Args:
        repo_path: Root directory path to index
        dialect: SQL dialect (ansi, snowflake, bigquery, postgres, tsql)

    Returns:
        Dict with keys: files_parsed, parse_errors, tables_found, lineage_edges_created

    Raises:
        ValueError: If repo_path does not exist or is not a directory.
        RuntimeError: If the backend has not been initialized.
    """
    start_time = time.perf_counter()

    try:
        db = _get_backend()
        indexer = Indexer()
        path = Path(repo_path).resolve()
        if not path.exists():
            raise ValueError(f"Repository path does not exist: {repo_path}")
        if not path.is_dir():
            raise ValueError(f"Repository path is not a directory: {repo_path}")

        from sqlcg.core.schema import NodeLabel, RelType

        # Ensure the Repo node exists for this repository.
        abs_path = str(path)
        db.upsert_node(
            NodeLabel.REPO,
            abs_path,
            {
                "path": abs_path,
                "name": path.name,
            },
        )

        # Index the repository (with absolute path).
        result = indexer.index_repo(path, dialect, db)

        # Link every File node under this repo to the Repo node (BELONGS_TO).
        # NOTE(review): the prefix is built with "/"; confirm how File paths are
        # stored on Windows, where str(path) uses backslashes.
        repo_prefix = abs_path.rstrip("/") + "/"
        for row in db.run_read(INDEX_REPO_FILES_QUERY, {"repo_prefix": repo_prefix}):
            db.upsert_edge(
                NodeLabel.FILE,
                row["path"],
                NodeLabel.REPO,
                abs_path,
                RelType.BELONGS_TO,
                {},
            )

        logger.info(f"Indexed {result['files_parsed']} files with {result['tables_found']} tables")
    except Exception:
        _record_tool_call("index_repo", (time.perf_counter() - start_time) * 1000, False)
        raise

    duration_ms = (time.perf_counter() - start_time) * 1000
    # Record the successful call too, matching what @_timed_tool does for the
    # other tools (only failures used to reach the tool-call metrics here).
    _record_tool_call("index_repo", duration_ms, True)

    # Per-run indexing statistics (best-effort).
    if _metrics is not None:
        try:
            _metrics.record_index_run(
                abs_path,
                result.get("files_parsed", 0),
                result.get("parse_errors", 0),
                result.get("tables_found", 0),
                result.get("lineage_edges_created", 0),
                duration_ms,
            )
        except Exception as exc:
            logger.warning(f"Failed to record index run metrics: {exc}")

    return result
266
+
267
+
268
@mcp.tool()
@_timed_tool("trace_column_lineage")
def trace_column_lineage(table_col: str, max_depth: int = 5) -> LineageResult:
    """Trace upstream lineage of a column.

    Breadth-first traversal of COLUMN_LINEAGE edges backward from the given
    column, following at most ``max_depth`` hops. Each upstream column is
    reported once, nearest ancestors first, even when it is reachable through
    several paths.

    Args:
        table_col: Column reference in format "table.column"
            or "catalog.db.table.column"
        max_depth: Maximum number of hops to traverse (0 or less returns an
            empty lineage)

    Returns:
        LineageResult with list of upstream column nodes

    Raises:
        NotIndexedError: If no repos have been indexed
        InvalidColumnRefError: If column reference format is invalid
    """
    db = _get_backend()
    _assert_indexed(db)

    # Raises InvalidColumnRefError for malformed references.
    table_id, col_name = _parse_column_ref(table_col)
    col_id = f"{table_id}.{col_name}"

    lineage: list[LineageNode] = []
    # Mark nodes as seen when first discovered (not when dequeued), so a column
    # reachable through several paths is appended only once.
    seen: set[str] = {col_id}
    queue: deque[tuple[str, int]] = deque([(col_id, 0)])

    while queue:
        current_id, depth = queue.popleft()
        # Expanding a node at depth == max_depth would yield nodes max_depth + 1 hops away.
        if depth >= max_depth:
            continue

        # Query for upstream columns (reverse direction).
        rows = db.run_read(TRACE_COLUMN_LINEAGE_QUERY, {"id": current_id})
        for row in rows:
            node_id = row["id"]
            if node_id in seen:
                continue
            seen.add(node_id)
            lineage.append(
                LineageNode(
                    name=row.get("col_name", ""),
                    kind="column",
                    file=None,
                    confidence=None,
                )
            )
            queue.append((node_id, depth + 1))

    return LineageResult(column=table_col, lineage=lineage)
330
+
331
+
332
@mcp.tool()
@_timed_tool("find_table_usages")
def find_table_usages(table_name: str) -> TableUsageResult:
    """Find all queries that use a given table.

    Looks up SELECTS_FROM relationships pointing at the table.

    Args:
        table_name: Table name to search for

    Returns:
        TableUsageResult with one entry per query using this table

    Raises:
        NotIndexedError: If no repos have been indexed
    """
    db = _get_backend()
    _assert_indexed(db)

    params = {"name": table_name}
    usages: list[TableUsage] = [
        TableUsage(query_file=row["file"], sql=row.get("sql"), kind=row.get("kind"))
        for row in db.run_read(FIND_TABLE_USAGES_QUERY, params)
    ]
    return TableUsageResult(table=table_name, usages=usages)
367
+
368
+
369
@mcp.tool()
@_timed_tool("get_downstream_dependencies")
def get_downstream_dependencies(table_col: str, max_depth: int = 5) -> DependencyResult:
    """Find all downstream dependencies of a column.

    Breadth-first traversal of COLUMN_LINEAGE edges forward from the given
    column, following at most ``max_depth`` hops. Each dependent column is
    reported once, nearest dependents first, even when it is reachable
    through several paths.

    Args:
        table_col: Column reference in format "table.column"
            or "catalog.db.table.column"
        max_depth: Maximum number of hops to traverse (0 or less returns no
            nodes)

    Returns:
        DependencyResult with list of downstream column nodes

    Raises:
        NotIndexedError: If no repos have been indexed
        InvalidColumnRefError: If column reference format is invalid
    """
    db = _get_backend()
    _assert_indexed(db)

    # Raises InvalidColumnRefError for malformed references.
    table_id, col_name = _parse_column_ref(table_col)
    col_id = f"{table_id}.{col_name}"

    nodes: list[DependencyNode] = []
    # Mark nodes as seen when first discovered (not when dequeued), so a column
    # reachable through several paths is appended only once.
    seen: set[str] = {col_id}
    queue: deque[tuple[str, int]] = deque([(col_id, 0)])

    while queue:
        current_id, depth = queue.popleft()
        # Expanding a node at depth == max_depth would yield nodes max_depth + 1 hops away.
        if depth >= max_depth:
            continue

        # Query for downstream columns (forward direction).
        rows = db.run_read(GET_DOWNSTREAM_DEPENDENCIES_QUERY, {"id": current_id})
        for row in rows:
            node_id = row["id"]
            if node_id in seen:
                continue
            seen.add(node_id)
            nodes.append(
                DependencyNode(
                    name=row.get("col_name", ""),
                    kind="column",
                )
            )
            queue.append((node_id, depth + 1))

    return DependencyResult(root=table_col, nodes=nodes)
429
+
430
+
431
@mcp.tool()
@_timed_tool("get_upstream_dependencies")
def get_upstream_dependencies(table_col: str, max_depth: int = 5) -> DependencyResult:
    """Find all upstream dependencies of a column.

    Breadth-first traversal of COLUMN_LINEAGE edges backward from the given
    column, following at most ``max_depth`` hops. Each source column is
    reported once, nearest sources first, even when it is reachable through
    several paths.

    Args:
        table_col: Column reference in format "table.column"
            or "catalog.db.table.column"
        max_depth: Maximum number of hops to traverse (0 or less returns no
            nodes)

    Returns:
        DependencyResult with list of upstream column nodes

    Raises:
        NotIndexedError: If no repos have been indexed
        InvalidColumnRefError: If column reference format is invalid
    """
    db = _get_backend()
    _assert_indexed(db)

    # Raises InvalidColumnRefError for malformed references.
    table_id, col_name = _parse_column_ref(table_col)
    col_id = f"{table_id}.{col_name}"

    nodes: list[DependencyNode] = []
    # Mark nodes as seen when first discovered (not when dequeued), so a column
    # reachable through several paths is appended only once.
    seen: set[str] = {col_id}
    queue: deque[tuple[str, int]] = deque([(col_id, 0)])

    while queue:
        current_id, depth = queue.popleft()
        # Expanding a node at depth == max_depth would yield nodes max_depth + 1 hops away.
        if depth >= max_depth:
            continue

        # Query for upstream columns (reverse direction).
        rows = db.run_read(GET_UPSTREAM_DEPENDENCIES_QUERY, {"id": current_id})
        for row in rows:
            node_id = row["id"]
            if node_id in seen:
                continue
            seen.add(node_id)
            nodes.append(
                DependencyNode(
                    name=row.get("col_name", ""),
                    kind="column",
                )
            )
            queue.append((node_id, depth + 1))

    return DependencyResult(root=table_col, nodes=nodes)
491
+
492
+
493
@mcp.tool()
@_timed_tool("search_sql_pattern")
def search_sql_pattern(query: str, limit: int = 20) -> SqlPatternResult:
    """Search for SQL patterns in indexed queries.

    Matches the pattern as a substring of each query's SQL text.

    Args:
        query: Pattern string to search for
        limit: Maximum number of results (default 20)

    Returns:
        SqlPatternResult with the matching queries

    Raises:
        NotIndexedError: If no repos have been indexed
    """
    db = _get_backend()
    _assert_indexed(db)

    params = {"query": query, "limit": limit}
    found = db.run_read(SEARCH_SQL_PATTERN_QUERY, params)
    matches: list[SqlPatternMatch] = [
        SqlPatternMatch(file=hit["file"], sql=hit.get("sql", ""), kind=hit.get("kind"))
        for hit in found
    ]
    return SqlPatternResult(pattern=query, matches=matches)
529
+
530
+
531
@mcp.tool()
@_timed_tool("list_dialects_and_repos")
def list_dialects_and_repos() -> DialectRepoResult:
    """List all indexed repositories and their SQL dialects.

    Returns:
        DialectRepoResult with one entry per repository and its dialects

    Raises:
        NotIndexedError: If no repos have been indexed
    """
    db = _get_backend()
    _assert_indexed(db)

    records = db.run_read(LIST_DIALECTS_AND_REPOS_QUERY, {})
    repos: list[DialectRepo] = [
        DialectRepo(
            path=record["path"],
            name=record.get("name"),
            dialects=record.get("dialects", []),
        )
        for record in records
    ]
    return DialectRepoResult(repos=repos)
561
+
562
+
563
@mcp.tool()
@_timed_tool("execute_cypher")
def execute_cypher(query: str) -> list[dict]:
    """Execute a read-only Cypher query against the graph.

    This tool allows direct Cypher queries for advanced users. It enforces
    read-only mode by blanking out quoted literals, backtick-quoted
    identifiers and comments, then checking the remaining text for write or
    administrative keywords. A LIMIT clause is automatically appended if
    missing.

    **Important Security Note**: Keywords that appear only inside string
    literals (e.g., 'DROP TABLE'), quoted identifiers or comments will NOT
    trigger the write-operation blocker. This is by design to allow querying
    SQL text that contains such keywords. Literals, identifiers and comments
    are recognised in a single left-to-right scan that understands backslash
    escapes, so a quote hidden inside a comment, identifier or another kind
    of literal cannot be used to smuggle a write clause past the check. This
    is a keyword filter, not a sandbox.

    Args:
        query: Cypher query string (read-only)

    Returns:
        List of result dictionaries from the query

    Raises:
        ValueError: If the query contains write or administrative operations
            (CREATE, MERGE, DELETE, SET, REMOVE, DROP, TRUNCATE, ALTER, COPY,
            INSTALL, LOAD, ATTACH, DETACH, EXPORT, IMPORT)
    """
    db = _get_backend()

    # Everything that is not executable Cypher text. A single alternation is
    # scanned left to right, so whichever construct starts first wins: a quote
    # inside a comment or backtick identifier does not open a string, and a
    # "'" inside a double-quoted string does not open a single-quoted one.
    non_code = (
        r"'(?:\\.|''|[^'\\])*'"  # single-quoted string (\-escapes or doubled quotes)
        r'|"(?:\\.|""|[^"\\])*"'  # double-quoted string
        r"|`(?:``|[^`])*`"  # backtick-quoted identifier
        r"|//[^\n]*"  # line comment
        r"|/\*.*?\*/"  # block comment
    )
    # Replace with a space (not "") so the tokens on either side never fuse.
    stripped = re.sub(non_code, " ", query, flags=re.DOTALL)

    # Check for write / administrative operations (case-insensitive). Beyond
    # graph mutations this covers statements that write files, change schema,
    # load extensions, read arbitrary files, or attach other databases.
    if re.search(
        r"\b(CREATE|MERGE|DELETE|SET|REMOVE|DROP|TRUNCATE"
        r"|ALTER|COPY|INSTALL|LOAD|ATTACH|DETACH|EXPORT|IMPORT)\b",
        stripped,
        re.IGNORECASE,
    ):
        raise ValueError(
            "Write operations are not permitted via execute_cypher. "
            "Use the CLI or dedicated tools instead."
        )
    # NOTE(review): CALL procedures are not inspected; some backends expose
    # mutating procedures. A read-only connection would be a stronger guard.

    # Auto-append LIMIT if missing (checked on the stripped text, as a whole
    # word, so identifiers like `limit_x` or literals do not count).
    q = query.rstrip()
    if q.endswith(";"):
        q = q[:-1].rstrip()
    if not re.search(r"\bLIMIT\b", stripped, re.IGNORECASE):
        # Newline rather than space: a trailing "// comment" would otherwise
        # swallow the appended clause.
        q = q + "\nLIMIT 500"

    try:
        return db.run_read(q, {})
    except Exception as e:
        logger.error(f"Cypher execution failed: {e}")
        raise
619
+
620
+
621
@mcp.tool()
def submit_feedback(
    tool_name: str,
    query: str,
    label: str,
    note: str = "",
) -> dict:
    """Submit feedback on a tool result.

    This tool allows users to correct the MCP server's results. Feedback
    is collected and analyzed to identify patterns and false positives.

    **For Claude**: When a user says "that result was wrong" or "this is a
    false positive", call this tool with label="FP". When they confirm
    "that's correct", call with label="TP". Use the query or pattern as
    the 'query' argument and include any user feedback in the 'note'.

    Args:
        tool_name: Name of the tool being evaluated (e.g., "trace_column_lineage")
        query: The query or pattern that was evaluated
        label: Feedback label: "TP" (true positive) or "FP" (false positive)
        note: Optional user note (truncated to 500 chars)

    Returns:
        Dict with status: "recorded" or "skipped"

    Raises:
        ValueError: If label is not "TP" or "FP"
    """
    if label not in ("TP", "FP"):
        raise ValueError(f"Invalid label: {label}. Must be 'TP' or 'FP'.")

    store = _metrics
    if store is None:
        return {"status": "skipped"}

    # Enforce the documented 500-character cap here so the contract holds
    # regardless of what the metrics store does; tolerate an explicit None.
    trimmed_note = (note or "")[:500]
    try:
        store.record_feedback(tool_name, query, label, trimmed_note)
    except Exception as exc:
        # Best-effort: feedback storage failures must not fail the tool call.
        logger.warning(f"Failed to record feedback: {exc}")
        return {"status": "skipped"}
    return {"status": "recorded"}
@@ -0,0 +1,6 @@
1
+ """Utility modules for sqlcg."""
2
+
3
+ from sqlcg.utils.hashing import hash_sql
4
+ from sqlcg.utils.logging import getLogger
5
+
6
# Public API of sqlcg.utils: the helpers re-exported from its submodules.
__all__ = ["getLogger", "hash_sql"]
sqlcg/utils/hashing.py ADDED
@@ -0,0 +1,18 @@
1
+ """SQL hashing utilities."""
2
+
3
+ import hashlib
4
+
5
+
6
+ def hash_sql(sql: str) -> str:
7
+ """Generate a SHA-256 hash of SQL content.
8
+
9
+ The SQL is normalized by stripping leading and trailing whitespace before hashing.
10
+
11
+ Args:
12
+ sql: SQL statement string
13
+
14
+ Returns:
15
+ SHA-256 hex digest of the normalized SQL bytes
16
+ """
17
+ normalized_sql = sql.strip()
18
+ return hashlib.sha256(normalized_sql.encode("utf-8")).hexdigest()